diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 8d3149f6..00000000 --- a/.coveragerc +++ /dev/null @@ -1,3 +0,0 @@ -[run] -parallel = True -source = mycli diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs index e69de29b..40790347 100644 --- a/.git-blame-ignore-revs +++ b/.git-blame-ignore-revs @@ -0,0 +1,6 @@ +# rename "toolkit" to "ptoolkit" +d891e5ae670c44b96ecd79fca36da91748d8c44a +4c8b5bfb314ed8998b8652613f0669fe58537a62 +# gather pytest files in a subdirectory +9dbad2c5be3786eacbb127362b9b37f41b4d4785 +eb4526c66f141734b806d4782968135772f39557 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..e0c58148 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,31 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + + + +### Suggested troubleshooting steps for bug reports + + * [ ] Upgraded to the latest mycli if possible. + * [ ] Ran `mycli --checkup`, if supported. + +### Expected Behavior + + +### Actual Behavior + + +### Steps to Reproduce + + +### System + + * mycli version: + * OS/version: + +### Discussion diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..e46a4c01 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,10 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8d498abc..2b0c282c 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -4,6 +4,10 @@ ## Checklist - -- [ ] I've added this contribution to the `changelog.md`. -- [ ] I've added my name to the `AUTHORS` file (or it's already there). + +- [ ] I added this contribution to the `changelog.md` file. +- [ ] I added my name to the `AUTHORS` file (or it's already there). +- [ ] To lint and format the code, I ran + ```bash + uv run ruff check && uv run ruff format && uv run mypy --install-types . + ``` diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..12301490 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,6 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb34daa3..50860a38 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,40 +1,35 @@ -name: mycli +name: CI on: pull_request: paths-ignore: - '**.md' + - '**.rst' + - 'LICENSE.txt' + - 'doc/**/*.txt' + - '**/AUTHORS' + - '**/SPONSORS' + - '**/TIPS' jobs: - linux: + tests: + name: Tests + runs-on: ubuntu-latest + strategy: + fail-fast: false matrix: - python-version: [ - '3.8', - '3.9', - '3.10', - '3.11', - '3.12', - ] - include: - - python-version: '3.8' - os: ubuntu-20.04 # MySQL 8.0.36 - - python-version: '3.9' - os: ubuntu-20.04 # MySQL 8.0.36 - - python-version: '3.10' - os: ubuntu-22.04 # MySQL 8.0.36 - - python-version: '3.11' - os: ubuntu-22.04 # MySQL 8.0.36 - - python-version: '3.12' - os: ubuntu-22.04 # MySQL 8.0.36 - - runs-on: ${{ matrix.os }} + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + version: "latest" - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: ${{ matrix.python-version }} @@ -43,10 +38,7 @@ jobs: sudo /etc/init.d/mysql start - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -r requirements-dev.txt - pip install --no-cache-dir -e . + run: uv sync --all-extras -p ${{ matrix.python-version }} - name: Wait for MySQL connection run: | @@ -58,14 +50,43 @@ jobs: env: PYTEST_PASSWORD: root PYTEST_HOST: 127.0.0.1 + TERM: xterm + run: | + uv run tox -e py${{ matrix.python-version }} + + test-no-extras: + name: Tests Without Extras + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: '3.13' + + - name: Start MySQL run: | - ./setup.py test --pytest-args="--cov-report= --cov=mycli" + sudo /etc/init.d/mysql start - - name: Lint + - name: Install dependencies + run: uv sync --extra dev -p python3.13 + + - name: Wait for MySQL connection run: | - ./setup.py lint --branch=HEAD + while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do + sleep 5 + done - - name: Coverage + - name: Pytest / behave + env: + PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 + TERM: xterm run: | - coverage combine - coverage report + uv run tox -e py3.13 diff --git a/.github/workflows/codex-review.yml b/.github/workflows/codex-review.yml new file mode 100644 index 00000000..778269fd --- /dev/null +++ b/.github/workflows/codex-review.yml @@ -0,0 +1,86 @@ +name: Codex Review + +on: + pull_request_target: + types: [opened, labeled, reopened, ready_for_review] + paths-ignore: + - '**.md' + - '**.rst' + - 'LICENSE.txt' + - 'doc/**/*.txt' + - '**/AUTHORS' + - '**/SPONSORS' + - '**/TIPS' + +jobs: + codex-review: + if: github.event.pull_request.draft == false || (github.event.action == 'labeled' && contains(github.event.pull_request.labels.*.name, 'codex')) + runs-on: ubuntu-latest + permissions: + contents: read + outputs: + final_message: ${{ steps.run_codex.outputs.final-message }} + + steps: + - name: Check out PR merge commit + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + ref: refs/pull/${{ github.event.pull_request.number }}/merge + + - name: Fetch base and head refs + run: | + git fetch --no-tags origin \ + ${{ github.event.pull_request.base.ref }} \ + +refs/pull/${{ github.event.pull_request.number }}/head + + - name: Run Codex review + id: run_codex + uses: openai/codex-action@e0fdf01220eb9a88167c4898839d273e3f2609d1 # v1.8 + env: + # Use env variables to handle untrusted metadata safely + PR_TITLE: ${{ github.event.pull_request.title }} + PR_BODY: ${{ github.event.pull_request.body }} + with: + openai-api-key: ${{ secrets.OPENAI_API_KEY }} + prompt: | + You are reviewing PR #${{ github.event.pull_request.number }} for ${{ github.repository }}. + + Only review changes introduced by this PR: + git log --oneline ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} + + Focus on: + - correctness bugs and regressions + - security concerns + - missing tests or edge cases + + Keep feedback concise and actionable. + + Pull request title and body: + ---- + $PR_TITLE + $PR_BODY + + post-feedback: + runs-on: ubuntu-latest + needs: codex-review + if: needs.codex-review.outputs.final_message != '' + permissions: + issues: write + pull-requests: write + + steps: + - name: Post Codex review as PR comment + uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0 + env: + CODEX_FINAL_MESSAGE: | + ${{ format('## Codex Review + {0}', needs.codex-review.outputs.final_message) }} + with: + github-token: ${{ github.token }} + script: | + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.payload.pull_request.number, + body: process.env.CODEX_FINAL_MESSAGE, + }); diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..1dc3c720 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,29 @@ +name: Lint + +on: + pull_request: + paths-ignore: + - '**.md' + - '**.rst' + - 'LICENSE.txt' + - 'doc/**/*.txt' + - '**/AUTHORS' + - '**/SPONSORS' + - '**/TIPS' + +jobs: + linters: + name: Linters + runs-on: ubuntu-latest + + steps: + - name: Check out Git repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Run ruff check + uses: astral-sh/ruff-action@0ce1b0bf8b818ef400413f810f8a11cdbda0034b # v4.0.0 + + - name: Run ruff format + uses: astral-sh/ruff-action@0ce1b0bf8b818ef400413f810f8a11cdbda0034b # v4.0.0 + with: + args: 'format --check' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 00000000..9a31c7a1 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,107 @@ +name: Publish Python Package + +on: + release: + types: [created] + +permissions: + contents: read + +jobs: + docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Require release changelog form + run: | + if grep -q TBD changelog.md; then false; fi + + test: + runs-on: ubuntu-latest + needs: [docs] + continue-on-error: true + + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + version: "latest" + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: ${{ matrix.python-version }} + + - name: Start MySQL + run: | + sudo /etc/init.d/mysql start + + - name: Install dependencies + run: uv sync --all-extras -p ${{ matrix.python-version }} + + - name: Wait for MySQL connection + run: | + while ! mysqladmin ping --host=localhost --port=3306 --user=root --password=root --silent; do + sleep 5 + done + + - name: Pytest / behave + env: + PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 + run: | + uv run tox -e py${{ matrix.python-version }} + + # arguably this should be made identical to CI for PRs + - name: Run Style Checks + run: uv run tox -e style + + build: + runs-on: ubuntu-latest + needs: [test] + + steps: + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + version: "latest" + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: '3.13' + + - name: Install dependencies + run: uv sync --all-extras -p 3.13 + + - name: Build + run: uv build + + - name: Store the distribution packages + uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1 + with: + name: python-packages + path: dist/ + + publish: + name: Publish to PyPI + runs-on: ubuntu-latest + if: startsWith(github.ref, 'refs/tags/') + needs: [build] + environment: release + permissions: + id-token: write + steps: + - name: Download distribution packages + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + name: python-packages + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # v1.14.0 diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml new file mode 100644 index 00000000..ccae747d --- /dev/null +++ b/.github/workflows/typecheck.yml @@ -0,0 +1,38 @@ +name: Typecheck + +on: + pull_request: + paths-ignore: + - '**.md' + - '**.rst' + - 'LICENSE.txt' + - 'doc/**/*.txt' + - '**/AUTHORS' + - '**/SPONSORS' + - '**/TIPS' + +jobs: + typecheck: + name: Typecheck + runs-on: ubuntu-latest + + steps: + - name: Check out Git repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Set up Python + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: '3.13' + + - uses: astral-sh/setup-uv@08807647e7069bb48b6ef5acd8ec9567f424441b # v8.1.0 + with: + version: 'latest' + + - name: Install dependencies + run: uv sync --all-extras + + - name: Run mypy + run: | + uv run --no-sync --frozen -- python -m ensurepip + uv run --no-sync --frozen -- python -m mypy --no-pretty --install-types --non-interactive . diff --git a/.gitignore b/.gitignore index 970fcd4f..3489bec1 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,13 @@ .cache/ .coverage .coverage.* +.mypy_cache/ +.pytest_cache/ +.ruff_cache/ +.tox/ .venv/ venv/ + +.myclirc +uv.lock diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..51766e29 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,122 @@ +# MyCli + +A command line client for MySQL with auto-completion and syntax highlighting. + +## Project Structure + +/ # repository root +├── .github/ # GitHub Actions and configuration +├── pyproject.toml # project configuration +├── doc/ # documentation +├── mycli/ # application source +├── mycli/__init__.py # provides version number +├── mycli/clibuffer.py # prompt_toolkit buffer utilities +├── mycli/clistyle.py # prompt_toolkit style utilities +├── mycli/clitoolbar.py # prompt_toolkit toolbar utilities +├── mycli/compat.py # OS compatibility helpers +├── mycli/completion_refresher.py # populates a `SQLCompleter` object in a background thread +├── mycli/config.py # configuration file readers and utilities +├── mycli/constants.py # shared constants +├── mycli/key_bindings.py # prompt_toolkit key binding utilities +├── mycli/lexer.py # extends `MySqlLexer` from Pygments +├── mycli/magic.py # Jupyter notebook magics +├── mycli/main.py # CLI main, configuration processing, and REPL +├── mycli/main_modes/ # main execution paths +├── mycli/main_modes/batch.py # batch mode execution path +├── mycli/myclirc # project-level configuration file +├── mycli/packages/ # application packages +├── mycli/packages/batch_utils.py # utilities for `--batch` mode +├── mycli/packages/checkup.py # implementation of `--checkup` mode +├── mycli/packages/cli_utils.py # utilities for parsing CLI arguments +├── mycli/packages/completion_engine.py # implementation of completion suggestions +├── mycli/packages/filepaths.py # utilities for files, including completion suggestions +├── mycli/packages/hybrid_redirection.py # implementation of shell-style redirects +├── mycli/packages/interactive_utils.py # utilities for confirming on destructive statements +├── mycli/packages/paramiko_stub/ # stub in case the Paramiko library is not installed +├── mycli/packages/sql_utils.py # utilities for parsing SQL statements +├── mycli/packages/ptoolkit/ # extends prompt_toolkit +├── mycli/packages/shortcuts.py # utilities for keyboard shortcuts +├── mycli/packages/special/ # implementation of mycli special commands +├── mycli/packages/sqlresult.py # the `SQLResult` dataclass for holding responses +├── mycli/packages/string_utils.py # generic string utilities +├── mycli/packages/tabular_output/ # extends cli_helper with additional output formats +├── mycli/sqlcompleter.py # offers SQL completions +├── mycli/sqlexecute.py # runs SQL queries +├── test/conftest.py # pytest configuration +├── test/features/ # behave tests +├── test/myclirc # mycli configuration used for tests +├── test/mylogin.cnf # `mylogin.cnf` example used for tests +├── test/pytests/ # pytest tests +└── test/utils.py # shared utilities for tests + +## Development + +### Python + +#### Python Dependency Management + +This repo uses `uv` for dependency management. **Always** prefix Python +commands with `uv run`. Example: + +```bash +uv run -- python script.py +``` + +#### Python Typing + +This repo uses type annotations which are checked by `mypy`. **Always** add +type annotations, and always check new code with `uv run -- mypy --install-types --non-interactive script.py`. + +Use lower-case type annotations such as `tuple`, not upper-case type +annotations such as `Tuple`. + +Use `Type | None` instead of `Optional[Type]`. + +#### Python Testing + +Tests are coordinated by `tox`, and include both `pytest` and `behave` tests. +To run the full test suite, execute `uv run -- tox`. + +#### Python Compatibility + +Use Python features available from Python 3.10 through Python 3.14. +Compatibility with Python 3.9 is not needed. + +#### Python Style + +Import style: prefer `from package import name` over `import package.name as name`. + +Quoting style: prefer single quotes for new code, but do not remove double quotes +from existing code. + +#### Python Environment + + * Package manager: `uv` (not pip) + * Formatter: `uv run -- ruff format` + * Linter: `uv run -- ruff check` + * Type checker: `uv run -- mypy --install-types --non-interactive` + +### Git Workflows + +#### Git Commit Messages + + * Use the present tense. + * Keep the first line under 50 characters in length. + * Keep the second line blank. + * Keep all other lines under 72 characters in length. + * Reference issue numbers when available. + +#### Generating PRs + +When generating a PR, follow the instructions in `.github/PULL_REQUEST_TEMPLATE.md`: + + * Add new author names to `mycli/AUTHORS`. + * Add a new entry to `changelog.md`. + +### Code Comments + +Keep comments concise and direct. Use full sentences, ending with a period. + +### See Also + +See also the file `CONTRIBUTING.md`. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index cac4f04e..945b0790 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -19,78 +19,60 @@ You'll always get credit for your work. $ git remote add upstream git@github.com:dbcli/mycli.git ``` -4. Set up a [virtual environment](http://docs.python-guide.org/en/latest/dev/virtualenvs) +4. Set up [uv](https://docs.astral.sh/uv/getting-started/installation/) for development: ```bash $ cd mycli - $ pip install virtualenv - $ virtualenv mycli_dev + $ uv sync --extra dev ``` - We've just created a virtual environment that we'll use to install all the dependencies - and tools we need to work on mycli. Whenever you want to work on mycli, you - need to activate the virtual environment: + We've just created a virtual environment and installed all the dependencies + and tools we need to work on mycli. - ```bash - $ source mycli_dev/bin/activate - ``` - - When you're done working, you can deactivate the virtual environment: - - ```bash - $ deactivate - ``` - -5. Install the dependencies and development tools: - - ```bash - $ pip install -r requirements-dev.txt - $ pip install --editable . - ``` - -6. Create a branch for your bugfix or feature based off the `main` branch: +5. Create a branch for your bugfix or feature based off the `main` branch: ```bash $ git checkout -b main ``` -7. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: +6. While you work on your bugfix or feature, be sure to pull the latest changes from `upstream`. This ensures that your local codebase is up-to-date: ```bash $ git pull upstream main ``` -8. When your work is ready for the mycli team to review it, push your branch to your fork: +7. When your work is ready for the mycli team to review it, push your branch to your fork: ```bash $ git push origin ``` -9. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) +8. [Create a pull request](https://help.github.com/articles/creating-a-pull-request-from-a-fork/) on GitHub. -## Running the Tests +## Running mycli -While you work on mycli, it's important to run the tests to make sure your code -hasn't broken any existing functionality. To run the tests, just type in: +To run mycli with your local changes: ```bash -$ ./setup.py test +$ uv run mycli ``` -Mycli supports Python 2.7 and 3.4+. You can test against multiple versions of -Python by running tox: + +## Running the Tests + +While you work on mycli, it's important to run the tests to make sure your code +hasn't broken any existing functionality. To run the tests, just type in: ```bash -$ tox +$ uv run tox ``` - ### Test Database Credentials -The tests require a database connection to work. You can tell the tests which +Some tests require a database connection to work. You can tell the tests which credentials to use by setting the applicable environment variables: ```bash @@ -98,10 +80,10 @@ $ export PYTEST_HOST=localhost $ export PYTEST_USER=mycli $ export PYTEST_PASSWORD=myclirocks $ export PYTEST_PORT=3306 -$ export PYTEST_CHARSET=utf8 +$ export PYTEST_CHARSET=utf8mb4 ``` -The default values are `localhost`, `root`, no password, `3306`, and `utf8`. +The default values are `localhost`, `root`, no password, `3306`, and `utf8mb4`. You only need to set the values that differ from the defaults. If you would like to run the tests as a user with only the necessary privileges, @@ -125,43 +107,12 @@ You can check this by running: $ readlink -f $(which ex) ``` +# Github PR checklist +- add the contribution to the `changelog.md` +- add your name to the `AUTHORS` file (or it's already there). +- run `uv run ruff check && uv run ruff format && uv run mypy --install-types .` -## Coding Style - -Mycli requires code submissions to adhere to -[PEP 8](https://www.python.org/dev/peps/pep-0008/). -It's easy to check the style of your code, just run: - -```bash -$ ./setup.py lint -``` - -If you see any PEP 8 style issues, you can automatically fix them by running: - -```bash -$ ./setup.py lint --fix -``` - -Be sure to commit and push any PEP 8 fixes. ## Releasing a new version of mycli -You have been made the maintainer of `mycli`? Congratulations! We have a release script to help you: - -```sh -> python release.py --help -Usage: release.py [options] - -Options: - -h, --help show this help message and exit - -c, --confirm-steps Confirm every step. If the step is not confirmed, it - will be skipped. - -d, --dry-run Print out, but not actually run any steps. -``` - -To release a new version of the package: - -* Create and merge a PR to bump the version in the changelog ([example PR](https://github.com/dbcli/mycli/pull/1043)). -* Pull `main` and bump the version number inside `mycli/__init__.py`. Do not check in - the release script will do that. -* Make sure you have the dev requirements installed: `pip install -r requirements-dev.txt -U --upgrade-strategy only-if-needed`. -* Finally, run the release script: `python release.py`. +Create a new [release](https://github.com/dbcli/mycli/releases) in Github. This will trigger a Github action which will run all the tests, build the wheel and upload it to PyPI. diff --git a/LICENSE.txt b/LICENSE.txt index 7b4904e2..7db7b58b 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,34 +1,27 @@ +Copyright (c) 2015-2026, mycli maintainers All rights reserved. -Redistribution and use in source and binary forms, with or without modification, -are permitted provided that the following conditions are met: +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, this - list of conditions and the following disclaimer in the documentation and/or - other materials provided with the distribution. +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. -* Neither the name of the {organization} nor the names of its +* Neither the name of mycli nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -------------------------------------------------------------------------------- - -This program also bundles with it python-tabulate -(https://pypi.python.org/pypi/tabulate) library. This library is licensed under -MIT License. - -------------------------------------------------------------------------------- +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in index 04f4d9a9..c885fa72 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ -include LICENSE.txt *.md *.rst requirements-dev.txt screenshots/* -include tasks.py .coveragerc tox.ini +include LICENSE.txt *.md *.rst doc/screenshots/* recursive-include test *.cnf recursive-include test *.feature recursive-include test *.py diff --git a/README.md b/README.md index 0a431437..ff6d99da 100644 --- a/README.md +++ b/README.md @@ -4,118 +4,44 @@ A command line client for MySQL that can do auto-completion and syntax highlighting. -HomePage: [http://mycli.net](http://mycli.net) -Documentation: [http://mycli.net/docs](http://mycli.net/docs) +Homepage: [https://mycli.net](https://mycli.net) +Documentation: [https://mycli.net/docs](https://mycli.net/docs) -![Completion](screenshots/tables.png) -![CompletionGif](screenshots/main.gif) +![Completion](doc/screenshots/tables.png) +![CompletionGif](doc/screenshots/main.gif) -Postgres Equivalent: [http://pgcli.com](http://pgcli.com) +Postgres Equivalent: [https://pgcli.com](https://pgcli.com) Quick Start ----------- -If you already know how to install python packages, then you can install it via pip: +If you already know how to install Python packages, then you can install it via `pip`: -You might need sudo on linux. +You might need sudo on Linux. -``` -$ pip install -U mycli +```bash +pip install -U 'mycli[all]' ``` or -``` -$ brew update && brew install mycli # Only on macOS +```bash +brew update && brew install mycli # Only on macOS ``` or -``` -$ sudo apt-get install mycli # Only on debian or ubuntu +```bash +sudo apt-get install mycli # Only on Debian or Ubuntu ``` ### Usage - $ mycli --help - Usage: mycli [OPTIONS] [DATABASE] - - A MySQL terminal client with auto-completion and syntax highlighting. - - Examples: - - mycli my_database - - mycli -u my_user -h my_host.com my_database - - mycli mysql://my_user@my_host.com:3306/my_database - - Options: - -h, --host TEXT Host address of the database. - -P, --port INTEGER Port number to use for connection. Honors - $MYSQL_TCP_PORT. - - -u, --user TEXT User name to connect to the database. - -S, --socket TEXT The socket file to use for connection. - -p, --password TEXT Password to connect to the database. - --pass TEXT Password to connect to the database. - --ssh-user TEXT User name to connect to ssh server. - --ssh-host TEXT Host name to connect to ssh server. - --ssh-port INTEGER Port to connect to ssh server. - --ssh-password TEXT Password to connect to ssh server. - --ssh-key-filename TEXT Private key filename (identify file) for the - ssh connection. - - --ssh-config-path TEXT Path to ssh configuration. - --ssh-config-host TEXT Host to connect to ssh server reading from ssh - configuration. - - --ssl Enable SSL for connection (automatically - enabled with other flags). - --ssl-ca PATH CA file in PEM format. - --ssl-capath TEXT CA directory. - --ssl-cert PATH X509 cert in PEM format. - --ssl-key PATH X509 key in PEM format. - --ssl-cipher TEXT SSL cipher to use. - --tls-version [TLSv1|TLSv1.1|TLSv1.2|TLSv1.3] - TLS protocol version for secure connection. - - --ssl-verify-server-cert Verify server's "Common Name" in its cert - against hostname used when connecting. This - option is disabled by default. - - -V, --version Output mycli's version. - -v, --verbose Verbose output. - -D, --database TEXT Database to use. - -d, --dsn TEXT Use DSN configured into the [alias_dsn] - section of myclirc file. - - --list-dsn list of DSN configured into the [alias_dsn] - section of myclirc file. - - --list-ssh-config list ssh configurations in the ssh config - (requires paramiko). - - -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). - -l, --logfile FILENAME Log every query and its results to a file. - --defaults-group-suffix TEXT Read MySQL config groups with the specified - suffix. - - --defaults-file PATH Only read MySQL options from the given file. - --myclirc PATH Location of myclirc file. - --auto-vertical-output Automatically switch to vertical output mode - if the result is wider than the terminal - width. - - -t, --table Display batch output in table format. - --csv Display batch output in CSV format. - --warn / --no-warn Warn before running a destructive query. - --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. - -g, --login-path TEXT Read this path from the login file. - -e, --execute TEXT Execute command and quit. - --init-command TEXT SQL statement to execute after connecting. - --charset TEXT Character set for MySQL session. - --password-file PATH File or FIFO path containing the password - to connect to the db if not specified otherwise - --help Show this message and exit. +See +```bash +mycli --help +``` Features -------- @@ -124,19 +50,26 @@ Features * Auto-completion as you type for SQL keywords as well as tables, views and columns in the database. +* Fuzzy history search using [fzf](https://github.com/junegunn/fzf). * Syntax highlighting using Pygments. * Smart-completion (enabled by default) will suggest context-sensitive completion. - `SELECT * FROM ` will only show table names. - `SELECT * FROM users WHERE ` will only show column names. * Support for multiline queries. * Favorite queries with optional positional parameters. Save a query using - `\fs alias query` and execute it with `\f alias` whenever you need. + `\fs ` and execute it with `\f `. * Timing of sql statements and table rendering. -* Config file is automatically created at ``~/.myclirc`` at first launch. * Log every query and its results to a file (disabled by default). -* Pretty prints tabular data (with colors!) +* Pretty print tabular data (with colors!). * Support for SSL connections -* Some features are only exposed as [key bindings](doc/key_bindings.rst) +* Shell-style trailing redirects with `$>`, `$>>` and `$|` operators. +* Support for querying LLMs with context derived from your schema. +* Support for storing passwords in the system keyring. + +Mycli creates a config file `~/.myclirc` on first run; you can use the +options in that file to configure the above features, and more. + +Some features are only exposed as [key bindings](doc/key_bindings.rst). Contributions: -------------- @@ -147,20 +80,17 @@ get this running in a development setup. https://github.com/dbcli/mycli/blob/main/CONTRIBUTING.md -Please feel free to reach out to me if you need help. - -My email: amjith.r@gmail.com -Twitter: [@amjithr](http://twitter.com/amjithr) +## Additional Install Instructions: -## Detailed Install Instructions: +These are some alternative ways to install mycli that are not managed by our team but provided by OS package maintainers. These packages could be slightly out of date and take time to release the latest version. ### Arch, Manjaro You can install the mycli package available in the AUR: ``` -$ yay -S mycli +yay -S mycli ``` ### Debian, Ubuntu @@ -168,7 +98,7 @@ $ yay -S mycli On Debian, Ubuntu distributions, you can easily install the mycli package using apt: ``` -$ sudo apt-get install mycli +sudo apt-get install mycli ``` ### Fedora @@ -176,25 +106,39 @@ $ sudo apt-get install mycli Fedora has a package available for mycli, install it using dnf: ``` -$ sudo dnf install mycli +sudo dnf install mycli ``` ### Windows -Follow the instructions on this blogpost: http://web.archive.org/web/20221006045208/https://www.codewall.co.uk/installing-using-mycli-on-windows/ +#### Option 1: Native Windows + +Install the `less` pager, for example by `scoop install less`. + +Follow the instructions on this blogpost: https://web.archive.org/web/20221006045208/https://www.codewall.co.uk/installing-using-mycli-on-windows/ + +**Mycli is not tested on Windows**, but the libraries used in the app are Windows-compatible. +This means it should work without any modifications, but isn't supported. + +PRs to add native Windows testing to Mycli CI would be welcome! + +#### Option 2: WSL + +Everything should work as expected in WSL. This is a good option for using +Mycli on Windows. ### Thanks: -This project was funded through kickstarter. My thanks to the [backers](http://mycli.net/sponsors) who supported the project. +This project was funded through kickstarter. My thanks to the [backers](https://mycli.net/sponsors) who supported the project. A special thanks to [Jonathan Slenders](https://twitter.com/jonathan_s) for -creating [Python Prompt Toolkit](http://github.com/jonathanslenders/python-prompt-toolkit), +creating [Python Prompt Toolkit](https://github.com/jonathanslenders/python-prompt-toolkit), which is quite literally the backbone library, that made this app possible. Jonathan has also provided valuable feedback and support during the development of this app. -[Click](http://click.pocoo.org/) is used for command line option parsing +[Click](https://palletsprojects.com/projects/click) is used for command line option parsing and printing error messages. Thanks to [PyMysql](https://github.com/PyMySQL/PyMySQL) for a pure python adapter to MySQL database. @@ -202,17 +146,22 @@ Thanks to [PyMysql](https://github.com/PyMySQL/PyMySQL) for a pure python adapte ### Compatibility -Mycli is tested on macOS and Linux, and requires Python 3.7 or better. +Mycli is tested on macOS and Linux, and requires Python 3.10 or better. + +To connect to MySQL versions earlier than 5.5, you may need to set the following in `~/.myclirc`: + +``` +# character set for connections without --charset being set at the CLI +default_character_set = utf8 +``` -**Mycli is not tested on Windows**, but the libraries used in this app are Windows-compatible. -This means it should work without any modifications. If you're unable to run it -on Windows, please [file a bug](https://github.com/dbcli/mycli/issues/new). +or set `--charset=utf8` when invoking MyCLI. ### Configuration and Usage -For more information on using and configuring mycli, [check out our documentation](http://mycli.net/docs). +For more information on using and configuring mycli, [check out our documentation](https://mycli.net/docs). Common topics include: -- [Configuring mycli](http://mycli.net/config) -- [Using/Disabling the pager](http://mycli.net/pager) -- [Syntax colors](http://mycli.net/syntax) +- [Configuring mycli](https://mycli.net/config) +- [Using/Disabling the pager](https://mycli.net/pager) +- [Syntax colors](https://mycli.net/syntax) diff --git a/changelog.md b/changelog.md index ffe31314..990d4e48 100644 --- a/changelog.md +++ b/changelog.md @@ -1,32 +1,1217 @@ -Upcoming Release (TBD) +Upcoming (TBD) +============== + +Features +--------- +* Update `cli_helpers` to v2.15.0 for `mysql_heavy` table format. + + +Bug Fixes +--------- +* Respect `history_file` setting in the `[main]` section of `~/.myclirc`. +* Adapt test suite to pygments v2.20.0. + +Internal +--------- +* Factor `app_state.py`, `cli_args.py`, and `output.py` out of `main.py`. + + +1.72.1 (2026/05/11) +============== + +Bug Fixes +--------- +* Update `sqlglot` to v30.7.0 to fix has_bit_strings error. + + +1.72.0 (2026/05/08) +============== + +Features +--------- +* Allow styling prompts with HTML-like tags. + + +Bug Fixes +--------- +* Gracefully fail on background completion-refresh connection issues. + + +Documentation +--------- +* Document the `\g` special command to send a query. + + +Internal +--------- +* Independent case-sensitivity for special-command aliases. + + +1.71.0 (2026/05/01) +============== + +Features +--------- +* Add more output to the `status` command. +* Respond to `help ` on builtin special commands. + + +Documentation +--------- +* Give example for ANSI prompt colors in `~/.myclirc`. +* Fix typos in `TIPS` file. +* Lightly reorganize `AUTHORS` file. + + +Internal +--------- +* Remove unused fixture data. +* More test coverage for completion prefetch. +* More test coverage for `--resume`. +* Upgrade `cli_helpers` dependency to v2.14.0. +* Require `prompt_toolkit>=3.0.41`. + + +1.70.0 (2026/04/24) +============== + +Features +--------- +* Add option to prefetch completion metadata for some or all schemas. +* Save fetched completion metadata when switching schemas. + + +1.69.0 (2026/04/20) +============== + +Features +--------- +* Remove undocumented `%mycli` Jupyter magic. +* Add `--quiet` option, and let `--verbose` be given multiple times. +* Add `--resume` to replay `--checkpoint` files with `--batch`. + + +Bug Fixes +--------- +* Make LLM timings use the same format as other timings. + + +Internal +--------- +* Commentary and organization in default/package myclirc file. + + +1.68.1 (2026/04/16) +============== + +Bug Fixes +--------- +* Upgrade `sqlglot` to v30.4.3, which may fix a build problem. + + +1.68.0 (2026/04/13) +============== + +Features +--------- +* Continue to expand TIPS. +* Make `--progress` and `--checkpoint` strictly by statement. +* Allow more characters in passwords read from a file. +* Show sponsors and contributors separately in startup messages. +* Add support for expired password (sandbox) mode (#440). +* Make balanced-bracket highlight colors configurable. +* Don't persist password-change statements to history file. + + +Bug Fixes +--------- +* Fix issue stripping multi-character end-of-statement delimiters. +* More conservative content truncation when sending to LLM APIs. +* More careful removal of redundant fuzzy completion suggestions. +* Fix a corner case when listing an empty list of favorite queries. +* Better completions refresh on changing databases or ALTERs. +* Make the return value of `FavoriteQueries.list()` a copy. +* Make multi-line detection and special cases more robust. +* Run empty `--execute` arguments instead of ignoring the flag. +* Exit with error when the `--batch` argument is an empty string. +* Avoid logging SSH passwords. + + +Internal +--------- +* Add an `AGENTS.md`. +* Refactor `find_matches()` into smaller logical units. +* Greatly increase test coverage. +* Remove some unused code. +* Better label Codex PR reviews. +* Improve gitignored files. +* Continue improve naming for `prompt_toolkit` utilities. +* Run pytest tests in arbitrary order. +* Type annotation improvements for `parse_pygments_style()`. +* Upgrade `llm` dependency and set a minimum `pydantic_core` version. +* Refactor suggestion logic into declarative rules. +* Factor the `--batch` execution modes out of `main.py`. +* Move `--checkup` logic to the new `main_modes` with `--batch`. +* Move `--execute` logic to the new `main_modes` with `--batch`. +* Move `--list-dsn` logic to the new `main_modes` with `--batch`. +* Move `--list-ssh-config` logic to the new `main_modes` with `--batch`. +* Move REPL logic to the new `main_modes`, and refactor the REPL. +* Sort coverage report in tox suite. +* Skip more tests when a database connection is not present. +* Move SQL utilities to a new `sql_utils.py`. +* Move CLI utilities to a new `cli_utils.py`. +* Move keybinding utilities to a new `key_binding_utils.py`. +* Move interactive utilities to `interactive_utils.py`. +* Move special commands out of `main.py`. +* Modernize orthography of prompt_toolkit filters. +* Pin all GitHub Actions to hashes. +* Remove unused method `get_completions()`. + + +1.67.1 (2026/03/28) +============== + +Features +--------- +* Respond to `-h` alone with the helpdoc. +* Allow `--hostname` as an alias for `--host`. +* Suggest tables with foreign key relationships for JOIN and ON (#975). +* Deprecate `$DSN` environment variable in favor of `$MYSQL_DSN`. +* Add a `--progress` progress-bar option with `--batch`. + + +Bug Fixes +--------- +* Correct how password help is rendered in the helpdoc. +* Respect `--no-show-warnings`, overriding settings in `~/.myclirc`. + + +Internal +--------- +* Collect CLI arguments into a dataclass. +* Clean up generated files after test runs. +* Migrate toplevel tool configurations to `pyproject.toml`. +* Migrate other toplevel files to subdirectories. +* Gather `pytest` files into a subdirectory, separated from `behave` tests. +* Refactor: better naming for `prompt_toolkit` utilities. + + +1.66.0 (2026/03/21) +============== + +Features +--------- +* Add a `--batch` option as an alternative to STDIN. +* Deprecate `$MYSQL_UNIX_PORT` environment variable in favor of `$MYSQL_UNIX_SOCKET`. +* Support `--username` and `$MYSQL_USER` to set username. + + +Bug Fixes +--------- +* Revert suppression of warnings when `sqlglotrs` is installed (fixed upstream). +* Update `cli_helpers` to v2.12.0, fixing a `preserve_whitespace` bug with `tabulate`. + + +Internal +-------- +* Harden `codex-review` workflow against script injection from untrusted PR metadata. +* Handle Click exceptions by hand. +* Connect toolbar tests to the test database. + + +1.65.1 (2026/03/18) +============== + +Bug Fixes +--------- +* Require `sqlglot` 30.x. + + +1.65.0 (2026/03/16) +============== + +Features +--------- +* Add prompt format string for literal backslash. +* Add collation completions, and complete charsets in more positions. + + +Bug Fixes +--------- +* Suppress warnings when `sqlglotrs` is installed. +* Improve completions after operators, by recognizing more operators. + + +1.64.0 (2026/03/13) +============== + +Features +--------- +* Add `-r` raw mode to `system` command. +* Set timeouts, show exit codes, and improve formatting for `system` commands. +* Add a dependencies section to `--checkup`. + + +Bug Fixes +--------- +* Require `sqlglot` 29.x, suppressing a deprecation warning. + + +1.63.0 (2026/03/12) +============== + +Features +--------- +* Make short toolbar message show after one prompt. + + +Internal +--------- +* Migrate more repeated values to `constants.py`. +* Support `sqlglot` 28 and 29. + + +1.62.0 (2026/03/07) +============== + +Features +--------- +* Dynamic terminal titles based on prompt format strings. +* Ability to turn off the toolbar. +* Add completions for introducers on literals. +* Load whole-line autosuggest candidates in a background thread for speed. + + +Bug Fixes +--------- +* Improve query cancellation on control-c. +* Improve refresh of some format strings in the toolbar. +* Improve keyring storage, requiring re-entering most keyring passwords. +* Improve sentinel value for `--password` without argument. + + +Internal +--------- +* Require a more recent version of the `wcwidth` library. +* Make `safe_invalidate_display` function safer. + + +1.61.0 (2026/03/07) +============== + +Features +--------- +* Allow shorter timeout lengths after pressing Esc, for vi-mode. +* Let tab and control-space behaviors be configurable. +* Add short hostname prompt format string. + + +1.60.0 (2026/03/05) +============== + +Features +--------- +* Prioritize common functions in the "value" position. +* Improve value-position keywords. +* Allow warning-count in status output to be styled. + + +Bug Fixes +--------- +* Fix crash for completion edge case (#1668). +* Update to a `cli_helpers` version with a `tabulate` bugfix. + + +1.59.0 (2026/03/03) +============== + +Features +--------- +* Offer filename completions on more special commands, such as `\edit`. +* Allow styling of status and timings text. +* Set up customization of prompt/continuation colors in `~/.myclirc`. +* Allow customization of the toolbar with prompt format strings. +* Add warnings-count prompt format strings: `\w` and `\W`. +* Handle/document more attributes in the `[colors]` section of `~/.myclirc`. +* Enable customization of table border color/attributes in `~/.myclirc`. +* Complete much more precisely in the "value" position. + + +Bug Fixes +--------- +* Make toolbar widths consistent on toggle actions. +* Don't write ANSI prompt escapes to `tee` output. + + +Internal +--------- +* Use prompt_toolkit's `bell()`. +* Refactor `SQLResult` dataclass. +* Avoid depending on string matches into host info. +* Add more URL constants. +* Set `$VISUAL` whenever `$EDITOR` is set. +* Fix tempfile leak in test suite. +* Avoid refreshing the prompt unless needed. + + +1.58.0 (2026/02/28) +============== + +Features +--------- +* Add `\bug` command. +* Let the `F1` key open a browser to mycli.net/docs and emit help text. +* Add documentation index URL to inline help. +* Rewrite bottom toolbar, showing more statuses, but staying compact. +* Let `help ` list similar keywords when not found. +* Optionally highlight fuzzy search previews. +* Make `\edit` synonymous with the `\e` command. +* Add environment variable section to `--checkup`. + + +Bug Fixes +--------- +* Force a prompt_toolkit refresh after fzf history search to avoid display glitches. +* Include `status` footer in paged output. +* Ensure fullscreen in fuzzy history search. + + +Documentation +--------- +* Add `help ` to TIPS. +* Refine inline help descriptions. +* Add `$VISUAL` environment variable hint to TIPS. + + +Internal +--------- +* Better tests for `null_string` configuration option. +* Better cleanup of resources in the test suite. +* Simplify prettify/unprettify handlers. +* Make prettify/unprettify logic more robust. + + +1.57.0 (2026/02/25) +============== + +Features +--------- +* Add extra error output on connection failure for possible SSL mismatch (#1584). +* Bind alternate terminal sequences for function keys F2 - F4. +* Add `llm help` subcommand. +* Rewrite `help` table. +* Remove "info" counter from fzf history-search UI. + + +Bug Fixes +--------- +* Let interactive changes to the prompt format respect dynamically-computed values. +* Better handle arguments to `system cd`. +* Fix missing keepalives in `\e` prompt loop. +* Always strip trailing newlines with `\e`. +* Fix `\llm` without arguments, and remove debug output. + + +Documentation +--------- +* Startup tips: add right-arrow key binding. +* Startup tips: add control-space and the `min_completion_trigger` setting. +* Startup tips: add history-search bindings. +* Prefer `https` protocol over `http` in documentation. + + +Internal +--------- +* Remove outdated email address in `pyproject.toml`. +* Set well-known URL values in `pyproject.toml`. + + +1.56.0 (2026/02/23) +============== + +Features +--------- +* Let the `--dsn` argument accept literal DSNs as well as aliases. +* Accept `--character-set` as an alias for `--charset` at the CLI. +* Add SSL/TLS version to `status` output. +* Accept `socket` as a DSN query parameter. +* Accept new-style `ssl_mode` in DSN URI query parameters, to match CLI argument. +* Fully deprecate the built-in SSH functionality. +* Let `--keepalive-ticks` be set per-connection, as a CLI option or DSN parameter. +* Accept `character_set` as a DSN query parameter. +* Don't attempt SSL for local socket connections when in "auto" SSL mode. +* Add prompt format string for SSL/TLS version of the connection. +* Add prompt format strings for displaying uptime. +* Add batch mode to startup tips. +* Update startup tips with new options. + + +Bug Fixes +--------- +* Make `--ssl-capath` argument a directory. +* Allow users to use empty passwords without prompting or any configuration (#1584). +* Check the existence of a socket more directly in `status`. +* Allow multi-line SQL statements in batch mode on the standard input. +* Fix extraneous prompt refresh on every keystroke. + + +1.55.0 (2026/02/20) +============== + +Features +--------- +* `--checkup` now checks for external executables. + + +Bug Fixes +--------- +* Improve completion suggestions within backticks. +* Watch command now returns correct time when run as part of a multi-part query (#1565). +* Don't diagnose free-entry sections such as `[favorite_queries]` in `--checkup`. +* When accepting a filename completion, fill in leading `./` if given. + + +Internal +-------- +* Bump `cli_helpers` to non-yanked version. + + +1.54.1 (2026/02/17) +============== + +Bug Fixes +-------- +* Don't offer autocomplete suggestions when the cursor is within a string. +* Catch `getpwuid` error on unknown socket owner. + + +Internal +-------- +* Tune Codex reviews. +* Refactor `is_inside_quotes()` detection. + + +1.54.0 (2026/02/16) +============== + +Features +-------- +* Add many CLI flags to startup tips. +* Accept all special commands without trailing semicolons in multi-line mode. +* Add prompt format strings for socket connections. +* Optionally defer auto-completions until a minimum number of characters is typed. +* Make the completion interface more responsive using a background thread. +* Option to suppress control-d exit behavior. +* Better support Truecolor terminals. +* Ability to send app-layer keepalive pings to the server. +* Add `WITH`, `EXPLAIN`, and `LEFT JOIN` to favorite keyword suggestions. +* Let the Escape key cancel completion popups. + + +Bug Fixes +--------- +* Correct parameterization for completion queries. +* Grammar nits in help display. + + +Internal +-------- +* Prefer `yield from` over yielding in a loop. +* Update `ruff` linter and CI. +* Update `LICENSE.txt` for dates and GitHub detection. +* Update key feature list in `README.md`, syncing with web. +* Sync prompt format string commentary with web. +* Add a GitHub Actions workflow to run Codex review on pull requests. +* Remove vim-style exit sequence which had no effect. +* Pin dependencies more tightly in `pyproject.toml`. +* Exclude more documentation files from CI. + + +1.53.0 (2026/02/12) +============== + +Features +-------- +* Add all `~/.myclirc` entries/sections to startup tips. + + +Bug Fixes +--------- +* Fix `\dt+ table_name` returning empty results. +* Further bulletproof generating completions on stored procedures. + + +Internal +-------- +* Add GitHub Issue templates. + + +1.52.0 (2026/02/11) +============== + +Features +-------- +* Suggest tables/views that contain the given columns first when provided in a SELECT query. + + +Bug Fixes +-------- +* Reduce duplicated `--checkup` output. +* Handle errors generating completions on stored procedures. +* Fix whitespace/inline comments breaking destructive `UPDATE … WHERE` statement detection. + + +Internal +-------- +* Let CI ignore additional documentation files. +* Upgrade `cli_helpers` library to v2.10.0. +* Organize startup tips. + + +1.51.1 (2026/02/09) +============== + +Features +-------- +* Options to limit size of LLM prompts; cache LLM prompt data. +* Add startup usage tips. +* Move `main.ssl_mode` config option to `connection.default_ssl_mode`. +* Add "unsupported" and "deprecated" `--checkup` sections. + + +Bug Fixes +-------- +* Correct mangled schema info sent in LLM prompts. +* Give destructive warning on multi-table `UPDATE`s. + + +1.50.0 (2026/02/07) +============== + +Features +-------- +* Deprecate reading configuration values from `my.cnf` files. +* Add `--checkup` mode to show unconfigured new features. +* Add `binary_display` configuration option. + + +Bug Fixes +-------- +* Link to `--ssl`/`--no-ssl` GitHub issue in deprecation warning. +* Don't emit keyring-updated message unless needed. +* Include port and socket in keyring identifier. + + +1.49.0 (2026/02/02) +============== + +Features +-------- +* "Eager" completions for the `source` command, limited to `*.sql` files. +* Suggest column names from all tables in the current database after SELECT (#212). +* Put fuzzy completions more often to the bottom of the suggestion list. +* Store and retrieve passwords using the system keyring. + + +Bug Fixes +-------- +* Refactor completions for special commands, with minor casing fixes. +* Raise `--password-file` higher in the precedence of password specification. +* Fix regression: show username in password prompt. + + +Internal +-------- +* Remove `align_decimals` preprocessor, which had no effect. +* Fix TLS deprecation warning in test suite. +* Convert importlib read_text and open_text uses to newer files() syntax. +* Update Pull Request template. + + +1.48.0 (2026/01/27) +============== + +Features +-------- +* Right-align numeric columns, and make the behavior configurable. +* Add completions for stored procedures. +* Escape database completions. +* Offer completions on `CREATE TABLE ... LIKE`. +* Use 0x-style hex literals for binaries in SQL output formats. + + +Bug Fixes +-------- +* Better respect case when `keyword_casing` is `auto`. +* Fix error when selecting from an empty table. +* Let favorite queries contain special commands. +* Render binary values more consistently as hex literals. +* Offer format completions on special command `\Tr`/`redirectformat`. + + +1.47.0 (2026/01/24) +============== + +Features +-------- +* Add a `--checkpoint=` argument to log successful queries in batch mode. +* Add `--throttle` option for batch mode. + + +Bug Fixes +-------- +* Fix timediff output when the result is a negative value (#1113). +* Don't offer completions for numeric text. + + +1.46.0 (2026/01/22) +============== + +Features +-------- +* Add `--unbuffered` mode which fetches rows as needed, to save memory. +* Default to standards-compliant `utf8mb4` character set. +* Stream input from STDIN to consume less memory, adding `--noninteractive` and `--format=` CLI arguments. +* Remove suggested quoting on completions for identifiers with uppercase. +* Allow table names to be completed with leading schema names. +* Soft deprecate the built-in SSH features. +* Add true fuzzy-match completions with rapidfuzz. + + +Bug Fixes +-------- +* Fix CamelCase fuzzy matching. +* Place special commands first in the list of completion candidates, and remove duplicates. + + +1.45.0 (2026/01/20) +============== + +Features +-------- +* Make password options also function as flags. Reworked password logic to prompt user as early as possible (#341). +* More complete and up-to-date set of MySQL reserved words for completions. +* Place exact-leading completions first. +* Allow history file location to be configured. +* Make destructive-warning keywords configurable. +* Smarter fuzzy completion matches. + + +Bug Fixes +-------- +* Respect `--logfile` when using `--execute` or standard input at the shell CLI. +* Gracefully catch Paramiko parsing errors on `--list-ssh-config`. +* Downgrade to Paramiko 3.5.1 to avoid crashing on DSA SSH keys. +* Offer schema name completions in `GRANT ... ON` forms. + + +1.44.2 (2026/01/13) +============== + +Bug Fixes +-------- +* Update watch query output to display the correct execution time on all iterations (#763). +* Use correct database (if applicable) when reconnecting after a connection loss (#1437). + +Internal +-------- +* Create new data class to handle SQL/command results to make further code improvements easier. + + +1.44.1 (2026/01/10) +============== + +Bug Fixes +-------- +* Let `sqlparse` accept arbitrarily-large queries. + + +1.44.0 (2026/01/08) +============== + +Features +-------- +* Add enum value completions for WHERE/HAVING clauses. (#790) +* Add `show_favorite_query` config option to control query printing when running favorite queries. (#1118) + + +1.43.1 (2026/01/03) +============== + +Bug Fixes +-------- +* Prompt for password within SSL-auto retry flow. + + +1.43.0 (2026/01/02) +============== + +Features +-------- +* Update query processing functions to allow automatic show_warnings to work for more code paths like DDL. +* Add new ssl_mode config / --ssl-mode CLI option to control SSL connection behavior. This setting will supercede the + existing --ssl/--no-ssl CLI options, which are deprecated and will be removed in a future release. +* Rework reconnect logic to actually reconnect or create a new connection instead of simply changing the database (#746). +* Configurable string for missing values (NULLs) in outputs. + + +Bug Fixes +-------- +* Update the prompt display logic to handle an edge case where a socket is used without + a host being parsed from any other method (#707). + + +Internal +-------- +* Refine documentation for Windows. +* Target Python 3.10 for linting. +* Use fully-qualified pymysql exception classes. + + +1.42.0 (2025/12/20) +============== + +Features +-------- +* Add support for the automatic displaying of warnings after a SQL statement is executed. + May be set with the commands \W and \w, in the config file with show_warnings, or + with --show-warnings/--no-show-warnings on the command line. + + +Internal +-------- +* Improve robustness for flaky tests when publishing. +* Improve type annotations for latest mypy/type stubs. +* Set mypy version more strictly. + + +1.41.2 (2025/11/24) +============== + +Bug Fixes +-------- +* Close connection to server properly to avoid "Aborted connection" warnings in server logs. + +Internal +-------- +* Add ruff to developement dependencies. +* Update contributing guidelines to match GitHub pull request checklist. + + +1.41.1 (2025/11/15) +============== + +Bug Fixes +-------- +* Upgrade `click` to v8.3.1, resolving a longstanding pager bug. + + +Internal +-------- +* Include LLM dependencies in tox configuration. + + +1.41.0 (2025/11/01) +============== + +Features +-------- +* Make LLM dependencies an optional extra. + + +Bug Fixes +-------- +* Let LLM commands respect show-timing configuration. + + +Internal +-------- +* Add mypy to Pull Request template. +* Enable flake8-bugbear lint rules. +* Fix flaky editor-command tests in CI. +* Require release format of `changelog.md` when making a release. +* Improve type annotations on LLM driver. + + +1.40.0 (2025/10/14) +============== + +Features +-------- +* Support reconnecting to mysql server when the server restarts. + + +Internal +-------- +* Test on Python 3.14. +* Switch from pyaes to pycryptodomex as it seems to be more actively maintained. + + +1.39.1 (2025/10/06) +============== + +Bug Fixes +-------- +* Don't require `--ssl` argument when other SSL arguments are given. + + +1.39.0 (2025/09/30) +============== + +Features +-------- +* Support only Python 3.10+. + + +Bug Fixes +-------- +* Fixes use of incorrect ssl config after retrying connection with prompted password. +* Fix ssl_context always created. + + +Internal +-------- +Typing fix for `pymysql.connect()`. + + +1.38.4 (2025/09/06) +============== + +Bug Fixes +-------- +* Limit Alt-R bindings to Emacs mode. +* Fix timing being printed twice. + + +Internal +-------- +* Only read "my" configuration files once, rather than once per call to read_my_cnf_files. + + +1.38.3 (2025/08/21) +============== + +Bug Fixes +-------- +* Fix the infinite looping when `\llm` is called without args. + + +1.38.2 (2025/08/19) +====================== + +Bug Fixes +-------- +* Fix failure to save Favorite Queries. + + +1.38.1 (2025/08/19) +====================== + +Bug Fixes +-------- +* Partially fix Favorite Query completion crash. + + +Internal +-------- +* Improve CI workflow naming. + + +1.38.0 (2025/08/16) +====================== + +Features +-------- +* Add LLM support. + + +Bug Fixes +-------- +* Improve missing ssh-extras message. +* Fix repeated control-r in traditional reverse isearch. +* Fix spelling of `ssl-verify-server-cert` option. +* Improve handling of `ssl-verify-server-cert` False values. +* Guard against missing contributors file on startup. +* Friendlier errors on password-file failures. +* Better handle empty-string passwords. +* Permit empty-string passwords at the interactive prompt. + + +Internal +-------- +* Improve pull request template lint commands. +* Complete typehinting the non-test codebase. +* Modernization: conversion to f-strings. +* Modernization: remove more Python 2 compatibility logic. + + +1.37.1 (2025/07/28) +====================== + +Internal +-------- + +* Align LICENSE with SPDX format. +* Fix deprecated `license` specification format in `pyproject.toml`. + + +1.37.0 (2025/07/28) +====================== + +Features +-------- +* Show username in password prompt. +* Add `mysql` and `mysql_unicode` table formats. + + +Bug Fixes +-------- +* Help Windows installations find a working default pager. + + +Internal +-------- + +* Support only Python 3.9+ in `pyproject.toml`. +* Add linting suggestion to pull request template. +* Make CI names and properties more consistent. +* Enable typechecking for most of the non-test codebase. +* CI: turn off fail-fast matrix strategy. +* Remove unused Python 2 compatibility code. +* Also run CI tests without installing SSH extra dependencies. +* Update `cli_helpers` dependency, and list of table formats. + + +1.36.0 (2025/07/19) +====================== + +Features +-------- +* Make control-r reverse search style configurable. +* Make fzf search key bindings more compatible with traditional isearch. + + +Bug Fixes +-------- + +* Better reset after pipe command failures. + + +Internal +-------- + +* Add limited typechecking to CI. + + +1.35.0 (2025/07/18) +====================== + +Features +-------- + +* Support chained pipe operators such as `select first_name from users $| grep '^J' $| head -10`. +* Support trailing file redirects after pipe operators, such as `select 10 $| tail -1 $> ten.txt`. + + +1.34.4 (2025/07/15) +====================== + +Bug Fixes +-------- + +* Fix old-style `\pipe_once`. + + +1.34.3 (2025/07/14) +====================== + +Bug Fixes +-------- + +* Use only `communicate()` to communicate with subprocess. + + +1.34.2 (2025/07/12) +====================== + +Bug Fixes +-------- + +* Use plain `print()` to communicate with subprocess. + + +1.34.1 (2025/07/12) +====================== + +Internal +-------- + +* Bump cli_helpers dependency for corrected output formats. + + +1.34.0 (2025/07/11) +====================== + +Features +-------- + +* Post-save command hook for redirected output. + +Internal +-------- + +* Documentation cleanup. +* Bump cli_helpers dependency for more output formats. + + +1.33.0 (2025/07/07) ====================== -Bug Fixes: +Features +-------- + +* Keybindings to insert current date/datetime. +* Improve feedback when running external commands. +* Independent format for redirected output. +* Trailing shell-style redirect syntax. + + +Internal +-------- + +* Remove `requirements-dev.txt` in favor of uv/`pyproject.toml`. + + +1.32.0 (2025/07/04) +====================== + +Features +-------- + +* Support SSL query parameters on DSNs. +* More information and care on KeyboardInterrupt. + +Internal +-------- + +* Work on passing `ruff check` linting. +* Relax expectation for unreliable test. +* Bump sqlglot version to v26 and add rs extras. + + +1.31.2 (2025/05/01) +=================== + +Bug Fixes +--------- + +* Let table-name extraction work on multi-statement inputs. + + +Internal +-------- + +* Work on passing `ruff check` linting. +* Remove backward-compatibility hacks. +* Pin more GitHub Actions and add Dependabot support. +* Enable xpassing test. + + +1.31.1 (2025/04/25) +=================== + +Internal +-------- + +* skip style checks on Publish action + + +1.31.0 (NEVER RELEASED) +=================== + +Features +-------- +* Added explicit error handle to get_password_from_file with EAFP. +* Use the "history" scheme for fzf searches. +* Deduplicate history in fzf searches. +* Add a preview window to fzf history searches. + +Internal +-------- + +* New Project Lead: [Roland Walker](https://github.com/rolandwalker) +* Update sqlparse to <=0.6.0 +* Typing/lint fixes. + + +1.30.0 (2025/04/19) +=================== + +Features +-------- + +* DSN specific init-command in myclirc. Fixes (#1195) +* Add `\\g` to force the horizontal output. + + +1.29.2 (2024/12/11) +=================== + +Internal +-------- + +* Exclude tests from the python package. + +1.29.1 (2024/12/11) +=================== + +Internal +-------- + +* Fix the GH actions to publish a new version. + +1.29.0 (NEVER RELEASED) +======================= + +Bug Fixes ---------- +* fix SSL through SSH jump host by using a true python socket for a tunnel +* Fix mycli crash when connecting to Vitess -Internal: +Internal --------- -Features: +* Modernize to use PEP-621. Use `uv` instead of `pip` in GH actions. +* Remove Python 3.8 and add Python 3.13 in test matrix. + +1.28.0 (2024/11/10) +====================== + +Features --------- * Added fzf history search functionality. The feature can switch between the old implementation and the new one based on the presence of the fzf binary. +Bug Fixes +---------- + +* Fixes `Database connection failed: error('unpack requires a buffer of 4 bytes')` +* Only show keyword completions after * +* Enable fuzzy matching for keywords 1.27.2 (2024/04/03) =================== -Bug Fixes: +Bug Fixes ---------- * Don't use default prompt when one is not supplied to the --prompt option. - 1.27.1 (2024/03/28) =================== -Bug Fixes: +Bug Fixes ---------- * Don't install tests. @@ -34,24 +1219,22 @@ Bug Fixes: * Fix unexpected exception when using dsn without username & password (Thanks: [Will Wang]) * Let the `--prompt` option act normally with its predefined default value - - -Internal: +Internal --------- + * paramiko is newer than 2.11.0 now, remove version pinning `cryptography`. * Drop support for Python 3.7 - 1.27.0 (2023/08/11) =================== -Features: +Features --------- * Detect TiDB instance, show in the prompt, and use additional keywords. * Fix the completion order to show more commonly-used keywords at the top. -Bug Fixes: +Bug Fixes ---------- * Better handle empty statements in un/prettify @@ -60,139 +1243,146 @@ Bug Fixes: * Correctly report the version of TiDB. * Revised `botton` spelling mistakes with `bottom` in `mycli/clitoolbar.py` - 1.26.1 (2022/09/01) =================== -Bug Fixes: +Bug Fixes ---------- -* Require Python 3.7 in `setup.py` +* Require Python 3.7 in `setup.py` 1.26.0 (2022/09/01) =================== -Features: +Features --------- * Add `--ssl` flag to enable ssl/tls. * Add `pager` option to `~/.myclirc`, for instance `pager = 'pspg --csv'` (Thanks: [BuonOmo]) * Add prettify/unprettify keybindings to format the current statement using `sqlglot`. - -Features: +Features --------- + * Add `--tls-version` option to control the tls version used. -Internal: +Internal --------- + * Pin `cryptography` to suppress `paramiko` warning, helping CI complete and presumably affecting some users. * Upgrade some dev requirements * Change tests to always use databases prefixed with 'mycli_' for better security -Bug Fixes: +Bug Fixes ---------- + * Support for some MySQL compatible databases, which may not implement connection_id(). * Fix the status command to work with missing 'Flush_commands' (mariadb) * Ignore the user of the system [myslqd] config. - 1.25.0 (2022/04/02) =================== -Features: +Features --------- -* Add `beep_after_seconds` option to `~/.myclirc`, to ring the terminal bell after long queries. +* Add `beep_after_seconds` option to `~/.myclirc`, to ring the terminal bell after long queries. 1.24.4 (2022/03/30) =================== -Internal: +Internal --------- + * Upgrade Ubuntu VM for runners as Github has deprecated it -Bug Fixes: +Bug Fixes ---------- -* Change in main.py - Replace the `click.get_terminal_size()` with `shutil.get_terminal_size()` - +* Change in main.py - Replace the `click.get_terminal_size()` with `shutil.get_terminal_size()` 1.24.3 (2022/01/20) =================== -Bug Fixes: +Bug Fixes ---------- -* Upgrade cli_helpers to workaround Pygments regression. +* Upgrade cli_helpers to workaround Pygments regression. 1.24.2 (2022/01/11) =================== -Bug Fixes: +Bug Fixes ---------- + * Fix autocompletion for more than one JOIN * Fix the status command when connected to TiDB or other servers that don't implement 'Threads\_connected' * Pin pygments version to avoid a breaking change -1.24.1: +1.24.1 ======= -Bug Fixes: +Bug Fixes --------- + * Restore dependency on cryptography for the interactive password prompt -Internal: +Internal --------- -* Deprecate Python mock +* Deprecate Python mock 1.24.0 ====== -Bug Fixes: +Bug Fixes ---------- + * Allow `FileNotFound` exception for SSH config files. * Fix startup error on MySQL < 5.0.22 * Check error code rather than message for Access Denied error * Fix login with ~/.my.cnf files -Features: +Features --------- + * Add `-g` shortcut to option `--login-path`. * Alt-Enter dispatches the command in multi-line mode. -* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html) +* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice ) -Internal: +Internal --------- + * Remove unused function is_open_quote() * Use importlib, instead of file links, to locate resources * Test various host-port combinations in command line arguments * Switched from Cryptography to pyaes for decrypting mylogin.cnf - 1.23.2 ====== -Bug Fixes: +Bug Fixes ---------- + * Ensure `--port` is always an int. 1.23.1 ====== -Bug Fixes: +Bug Fixes ---------- + * Allow `--host` without `--port` to make a TCP connection. 1.23.0 ====== -Bug Fixes: +Bug Fixes ---------- + * Fix config file include logic -Features: +Features --------- * Add an option `--init-command` to execute SQL after connecting (Thanks: [KITAGAWA Yasutaka]). @@ -203,10 +1393,11 @@ Features: * Add a special command `\pipe_once` to pipe output to a subprocess. * Add an option `--charset` to set the default charset when connect database. -Bug Fixes: +Bug Fixes ---------- + * Fixed compatibility with sqlparse 0.4 (Thanks: [mtorromeo]). -* Fixed iPython magic (Thanks: [mwcm]). +* Fixed iPython magic (Thanks: [mwcm]). * Send "Connecting to socket" message to the standard error. * Respect empty string for prompt_continuation via `prompt_continuation = ''` in `.myclirc` * Fix \once -o to overwrite output whole, instead of line-by-line. @@ -219,35 +1410,35 @@ Bug Fixes: 1.22.2 ====== -Bug Fixes: +Bug Fixes ---------- -* Make the `pwd` module optional. +* Make the `pwd` module optional. 1.22.1 ====== -Bug Fixes: +Bug Fixes ---------- + * Fix the breaking change introduced in PyMySQL 0.10.0. (Thanks: [Amjith]). -Features: +Features --------- + * Add an option `--ssh-config-host` to read ssh configuration from OpenSSH configuration file. * Add an option `--list-ssh-config` to list ssh configurations. * Add an option `--ssh-config-path` to choose ssh configuration path. -Bug Fixes: +Bug Fixes ---------- * Fix specifying empty password with `--password=''` when config file has a password set (Thanks: [Zach DeCook]). - 1.21.1 ====== - -Bug Fixes: +Bug Fixes ---------- * Fix broken auto-completion for favorite queries (Thanks: [Amjith]). @@ -257,8 +1448,9 @@ Bug Fixes: 1.21.0 ====== -Features: +Features --------- + * Added DSN alias name as a format specifier to the prompt (Thanks: [Georgy Frolov]). * Mark `update` without `where`-clause as destructive query (Thanks: [Klaus Wünschel]). * Added DELIMITER command (Thanks: [Georgy Frolov]) @@ -266,20 +1458,21 @@ Features: * Extend main.is_dropping_database check with create after delete statement. * Search `${XDG_CONFIG_HOME}/mycli/myclirc` after `${HOME}/.myclirc` and before `/etc/myclirc` (Thanks: [Takeshi D. Itoh]) -Bug Fixes: +Bug Fixes ---------- * Allow \o command more than once per session (Thanks: [Georgy Frolov]) * Fixed crash when the query dropping the current database starts with a comment (Thanks: [Georgy Frolov]) -Internal: +Internal --------- + * deprecate python versions 2.7, 3.4, 3.5; support python 3.8 1.20.1 ====== -Bug Fixes: +Bug Fixes ---------- * Fix an error when using login paths with an explicit database name (Thanks: [Thomas Roten]). @@ -287,14 +1480,15 @@ Bug Fixes: 1.20.0 ====== -Features: +Features ---------- + * Auto find alias dsn when `://` not in `database` (Thanks: [QiaoHou Peng]). * Mention URL encoding as escaping technique for special characters in connection DSN (Thanks: [Aljosha Papsch]). * Pressing Alt-Enter will introduce a line break. This is a way to break up the query into multiple lines without switching to multi-line mode. (Thanks: [Amjith Ramanujam]). * Use a generator to stream the output to the pager (Thanks: [Dick Marinus]). -Bug Fixes: +Bug Fixes ---------- * Fix the missing completion for special commands (Thanks: [Amjith Ramanujam]). @@ -304,28 +1498,29 @@ Bug Fixes: * Update `setup.py` to no longer require `sqlparse` to be less than 0.3.0 as that just came out and there are no notable changes. ([VVelox]) * workaround for ConfigObj parsing strings containing "," as lists (Thanks: [Mike Palandra]) -Internal: +Internal --------- + * fix unhashable FormattedText from prompt toolkit in unit tests (Thanks: [Dick Marinus]). 1.19.0 ====== -Internal: +Internal --------- * Add Python 3.7 trove classifier (Thanks: [Thomas Roten]). * Fix pytest in Fedora mock (Thanks: [Dick Marinus]). * Require `prompt_toolkit>=2.0.6` (Thanks: [Dick Marinus]). -Features: +Features --------- * Add Token.Prompt/Continuation (Thanks: [Dick Marinus]). * Don't reconnect when switching databases using use (Thanks: [Angelo Lupo]). * Handle MemoryErrors while trying to pipe in large files and exit gracefully with an error (Thanks: [Amjith Ramanujam]) -Bug Fixes: +Bug Fixes ---------- * Enable Ctrl-Z to suspend the app (Thanks: [Amjith Ramanujam]). @@ -333,12 +1528,12 @@ Bug Fixes: 1.18.2 ====== -Bug Fixes: +Bug Fixes ---------- * Fixes database reconnecting feature (Thanks: [Yang Zou]). -Internal: +Internal --------- * Update Twine version to 1.12.1 (Thanks: [Thomas Roten]). @@ -348,12 +1543,12 @@ Internal: 1.18.1 ====== -Features: +Features --------- * Add Keywords: TINYINT, SMALLINT, MEDIUMINT, INT, BIGINT (Thanks: [QiaoHou Peng]). -Internal: +Internal --------- * Update prompt toolkit (Thanks: [Jonathan Slenders], [Irina Truong], [Dick Marinus]). @@ -361,7 +1556,7 @@ Internal: 1.18.0 ====== -Features: +Features --------- * Display server version in welcome message (Thanks: [Irina Truong]). @@ -372,30 +1567,30 @@ Features: * Add `FROM_UNIXTIME` and `UNIX_TIMESTAMP` to SQLCompleter (Thanks: [QiaoHou Peng]) * Search `${PWD}/.myclirc`, then `${HOME}/.myclirc`, lastly `/etc/myclirc` (Thanks: [QiaoHao Peng]) -Bug Fixes: +Bug Fixes ---------- * When DSN is used, allow overrides from mycli arguments (Thanks: [Dick Marinus]). * A DSN without password should be allowed (Thanks: [Dick Marinus]) -Bug Fixes: +Bug Fixes ---------- * Convert `sql_format` to unicode strings for py27 compatibility (Thanks: [Dick Marinus]). * Fixes mycli compatibility with pbr (Thanks: [Thomas Roten]). * Don't align decimals for `sql_format` (Thanks: [Dick Marinus]). -Internal: +Internal --------- * Use fileinput (Thanks: [Dick Marinus]). * Enable tests for Python 3.7 (Thanks: [Thomas Roten]). * Remove `*.swp` from gitignore (Thanks: [Dick Marinus]). -1.17.0: +1.17.0 ======= -Features: +Features ---------- * Add `CONCAT` to SQLCompleter and remove unused code (Thanks: [caitinggui]) @@ -403,7 +1598,7 @@ Features: * Add option list-dsn (Thanks: [Frederic Aoustin]). * Add verbose option for list-dsn, add tests and clean up code (Thanks: [Dick Marinus]). -Bug Fixes: +Bug Fixes ---------- * Add enable_pager to the config file (Thanks: [Frederic Aoustin]). @@ -415,51 +1610,50 @@ Bug Fixes: * Quote CSV fields (Thanks: [Thomas Roten]). * Fix `thanks_picker` (Thanks: [Dick Marinus]). -Internal: +Internal --------- * Refactor Destructive Warning behave tests (Thanks: [Dick Marinus]). - -1.16.0: +1.16.0 ======= -Features: +Features --------- * Add DSN aliases to the config file (Thanks: [Frederic Aoustin]). -Bug Fixes: +Bug Fixes ---------- * Do not try to connect to a unix socket on Windows (Thanks: [Thomas Roten]). -1.15.0: +1.15.0 ======= -Features: +Features --------- * Add sql-update/insert output format. (Thanks: [Dick Marinus]). * Also complete aliases in WHERE. (Thanks: [Dick Marinus]). -1.14.0: +1.14.0 ======= -Features: +Features --------- * Add `watch [seconds] query` command to repeat a query every [seconds] seconds (by default 5). (Thanks: [David Caro](https://github.com/Terseus)) * Default to unix socket connection if host and port are unspecified. This simplifies authentication on some systems and matches mysql behaviour. * Add support for positional parameters to favorite queries. (Thanks: [Scrappy Soft](https://github.com/scrappysoft)) -Bug Fixes: +Bug Fixes ---------- * Fix source command for script in current working directory. (Thanks: [Dick Marinus]). * Fix issue where the `tee` command did not work on Python 2.7 (Thanks: [Thomas Roten]). -Internal Changes: +Internal Changes ----------------- * Drop support for Python 3.3 (Thanks: [Thomas Roten]). @@ -467,64 +1661,63 @@ Internal Changes: * Make tests more compatible between different build environments. (Thanks: [David Caro]) * Merge `_on_completions_refreshed` and `_swap_completer_objects` functions (Thanks: [Dick Marinus]). -1.13.1: +1.13.1 ======= -Bug Fixes: +Bug Fixes ---------- * Fix keyword completion suggestion for `SHOW` (Thanks: [Thomas Roten]). * Prevent mycli from crashing when failing to read login path file (Thanks: [Thomas Roten]). -Internal Changes: +Internal Changes ----------------- * Make tests ignore user config files (Thanks: [Thomas Roten]). -1.13.0: +1.13.0 ======= -Features: +Features --------- * Add file name completion for source command (issue #500). (Thanks: [Irina Truong]). -Bug Fixes: +Bug Fixes ---------- * Fix UnicodeEncodeError when editing sql command in external editor (Thanks: Klaus Wünschel). * Fix MySQL4 version comment retrieval (Thanks: [François Pietka]) * Fix error that occurred when outputting JSON and NULL data (Thanks: [Thomas Roten]). -1.12.1: +1.12.1 ======= -Bug Fixes: +Bug Fixes ---------- * Prevent missing MySQL help database from causing errors in completions (Thanks: [Thomas Roten]). * Fix mycli from crashing with small terminal windows under Python 2 (Thanks: [Thomas Roten]). * Prevent an error from displaying when you drop the current database (Thanks: [Thomas Roten]). -Internal Changes: +Internal Changes ----------------- * Use less memory when formatting results for display (Thanks: [Dick Marinus]). * Preliminary work for a future change in outputting results that uses less memory (Thanks: [Dick Marinus]). -1.12.0: +1.12.0 ======= -Features: +Features --------- * Add fish-style auto-suggestion from history. (Thanks: [Amjith Ramanujam]) - -1.11.0: +1.11.0 ======= -Features: +Features --------- * Handle reserved space for completion menu better in small windows. (Thanks: [Thomas Roten]). @@ -538,7 +1731,7 @@ Features: * Add colored/styled headers and odd/even rows (Thanks: [Thomas Roten]). * Keyword completion casing (upper/lower/auto) (Thanks: [Irina Truong]). -Bug Fixes: +Bug Fixes ---------- * Fixed incorrect timekeeping when running queries from a file. (Thanks: [Thomas Roten]). @@ -549,7 +1742,7 @@ Bug Fixes: * Support tilde user directory for output file names (Thanks: [Thomas Roten]). * Auto vertical output is a little bit better at its calculations (Thanks: [Thomas Roten]). -Internal Changes: +Internal Changes ----------------- * Rename tests/ to test/. (Thanks: [Dick Marinus]). @@ -568,10 +1761,10 @@ Internal Changes: * Add missing @dbtest to tests (Thanks: [Dick Marinus]). * Standardizes punctuation/grammar for help strings (Thanks: [Thomas Roten]). -1.10.0: +1.10.0 ======= -Features: +Features --------- * Add ability to specify alternative myclirc file. (Thanks: [Dick Marinus]). @@ -579,7 +1772,7 @@ Features: Ramanujam], [Dick Marinus], [Thomas Roten]). * Add logic to shorten the default prompt if it becomes too long once generated. (Thanks: [John Sterling]). -Bug Fixes: +Bug Fixes ---------- * Fix external editor bug (issue #377). (Thanks: [Irina Truong]). @@ -590,7 +1783,7 @@ Bug Fixes: (Thanks: [Thomas Roten]). * Use pymysql default conversions (issue #375). (Thanks: [Dick Marinus]). -Internal Changes: +Internal Changes ----------------- * Upload mycli distributions in a safer manner (using twine). (Thanks: [Thomas @@ -599,10 +1792,10 @@ Internal Changes: * Run pep8 checks in travis (Thanks: [Irina Truong]). * Remove temporary hack for sqlparse (Thanks: [Dick Marinus]). -1.9.0: +1.9.0 ====== -Features: +Features --------- * Add tee/notee commands for outputing results to a file. (Thanks: [Dick Marinus]). @@ -613,7 +1806,7 @@ Features: * Add `auto_vertical_output` config to myclirc. (Thanks: [Matheus Rosa]). * Improve Fedora install instructions. (Thanks: [Dick Marinus]). -Bug Fixes: +Bug Fixes ---------- * Fix crashes occuring from commands starting with #. (Thanks: [Zhidong]). @@ -625,7 +1818,7 @@ Bug Fixes: * Kill running query when interrupted via Ctrl-C. (Thanks: [chainkite]). * Read the `smart_completion` config from myclirc. (Thanks: [Thomas Roten]). -Internal Changes: +Internal Changes ----------------- * Improve handling of test database credentials. (Thanks: [Dick Marinus]). @@ -634,25 +1827,27 @@ Internal Changes: * Swap pycrypto dependency for pycryptodome. (Thanks: [Michał Górny]). * Bump sqlparse version so pgcli and mycli can be installed together. (Thanks: [darikg]). -1.8.1: +1.8.1 ====== -Bug Fixes: +Bug Fixes ---------- + * Remove duplicate listing of DISTINCT keyword. (Thanks: [Amjith Ramanujam]). * Add an try/except for AS keyword crash. (Thanks: [Amjith Ramanujam]). * Support python-sqlparse 0.2. (Thanks: [Dick Marinus]). * Fallback to the raw object for invalid time values. (Thanks: [Amjith Ramanujam]). * Reset the show items when completion is refreshed. (Thanks: [Amjith Ramanujam]). -Internal Changes: +Internal Changes ----------------- + * Make the dependency of sqlparse slightly more liberal. (Thanks: [Amjith Ramanujam]). -1.8.0: +1.8.0 ====== -Features: +Features --------- * Add support for --execute/-e commandline arg. (Thanks: [Matheus Rosa]). @@ -661,17 +1856,17 @@ Features: * Add `prompt_continuation` config option to allow configuring the continuation prompt for multi-line queries. (Thanks: [Scrappy Soft]). * Display login-path instead of host in prompt. (Thanks: [Irina Truong]). -Bug Fixes: +Bug Fixes ---------- * Pin sqlparse to version 0.1.19 since the new version is breaking completion. (Thanks: [Amjith Ramanujam]). * Remove unsupported keywords. (Thanks: [Matheus Rosa]). * Fix completion suggestion inside functions with operands. (Thanks: [Irina Truong]). -1.7.0: +1.7.0 ====== -Features: +Features --------- * Add stdin batch mode. (Thanks: [Thomas Roten]). @@ -680,20 +1875,20 @@ Features: * Update features list in README.md. (Thanks: [Matheus Rosa]). * Remove extra \n in features list in README.md. (Thanks: [Matheus Rosa]). -Bug Fixes: +Bug Fixes ---------- * Enable history search via . (Thanks: [Amjith Ramanujam]). -Internal Changes: +Internal Changes ----------------- * Upgrade `prompt_toolkit` to 1.0.0. (Thanks: [Jonathan Slenders]) -1.6.0: +1.6.0 ====== -Features: +Features --------- * Change continuation prompt for multi-line mode to match default mysql. @@ -706,14 +1901,14 @@ Features: * Add support for `nopager` and `\n` to turn off the pager. (Thanks: [Thomas Roten]). * Add support for `--local-infile` command-line option. (Thanks: [Thomas Roten]). -Bug Fixes: +Bug Fixes ---------- * Remove -S from `less` option which was clobbering the scroll back in history. (Thanks: [Thomas Roten]). * Make system command work with Python 3. (Thanks: [Thomas Roten]). * Support \G terminator for \f queries. (Thanks: [Terseus]). -Internal Changes: +Internal Changes ----------------- * Upgrade `prompt_toolkit` to 0.60. @@ -724,26 +1919,26 @@ Internal Changes: * Capture warnings to log file. (Thanks: [Mikhail Borisov]). * Make `syntax_style` a tiny bit more intuitive. (Thanks: [Phil Cohen]). -1.5.2: +1.5.2 ====== -Bug Fixes: +Bug Fixes ---------- * Protect against port number being None when no port is specified in command line. -1.5.1: +1.5.1 ====== -Bug Fixes: +Bug Fixes ---------- * Cast the value of port read from my.cnf to int. -1.5.0: +1.5.0 ====== -Features: +Features --------- * Make a config option to enable `audit_log`. (Thanks: [Matheus Rosa]). @@ -752,21 +1947,25 @@ Features: * Register the special command `prompt` with the `\R` as alias. (Thanks: [Matheus Rosa]). Users can now change the mysql prompt at runtime using `prompt` command. eg: + ``` mycli> prompt \u@\h> Changed prompt format to \u@\h> Time: 0.001s amjith@localhost> ``` + * Perform completion refresh in a background thread. Now mycli can handle databases with thousands of tables without blocking. * Add support for `system` command. (Thanks: [Matheus Rosa]). Users can now run a system command from within mycli as follows: + ``` amjith@localhost:(none)>system cat tmp.sql select 1; select * from django_migrations; ``` + * Caught and hexed binary fields in MySQL. (Thanks: [Daniel West]). Geometric fields stored in a database will be displayed as hexed strings. * Treat enter key as tab when the suggestion menu is open. (Thanks: [Matheus Rosa]) @@ -776,7 +1975,7 @@ Features: * Add TRANSACTION related keywords. * Treat DESC and EXPLAIN as DESCRIBE. (Thanks: [spacewander]). -Bug Fixes: +Bug Fixes ---------- * Fix the removal of whitespace from table output. @@ -784,23 +1983,25 @@ Bug Fixes: * Fix the incorrect reporting of command time. * Add type validation for port argument. (Thanks [Matheus Rosa]) -Internal Changes: +Internal Changes ----------------- + * Make pycrypto optional and only install it in \*nix systems. (Thanks: [Irina Truong]). * Add badge for PyPI version to README. (Thanks: [Shoma Suzuki]). * Updated release script with a --dry-run and --confirm-steps option. (Thanks: [Irina Truong]). * Adds support for PyMySQL 0.6.2 and above. This is useful for debian package builders. (Thanks: [Thomas Roten]). * Disable click warning. -1.4.0: +1.4.0 ====== -Features: +Features --------- * Add `source` command. This allows running sql statement from a file. eg: + ``` mycli> source filename.sql ``` @@ -819,29 +2020,33 @@ Features: Multi-line queries are automatically indented. -Bug Fixes: +Bug Fixes ---------- * Fix keyword completion after the `WHERE` clause. * Add `\g` and `\G` as valid query terminators. Previously in multi-line mode ending a query with a `\G` wouldn't run the query. This is now fixed. -1.3.0: +1.3.0 ====== -Features: +Features --------- + * Add a new special command (\T) to change the table format on the fly. (Thanks: [Jonathan Bruno](https://github.com/brewneaux)) eg: + ``` mycli> \T tsv ``` + * Add `--defaults-group-suffix` to the command line. This lets the user specify - a group to use in the my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet)) + a group to use in the my.cnf files. (Thanks: [Irina Truong](https://github.com/j-bennet)) In the my.cnf file a user can specify credentials for different databases and invoke mycli with the group name to use the appropriate credentials. eg: + ``` # my.cnf [client] @@ -863,79 +2068,77 @@ Features: * Make `-p` and `--password` take the password in commandline. This makes mycli a drop in replacement for mysql. -1.2.0: +1.2.0 ====== -Features: +Features --------- * Add support for wider completion menus in the config file. Add `wider_completion_menu = True` in the config file (~/.myclirc) to enable this feature. -Bug Fixes: +Bug Fixes --------- * Prevent Ctrl-C from quitting mycli while the pager is active. * Refresh auto-completions after the database is changed via a CONNECT command. -Internal Changes: +Internal Changes ----------------- * Upgrade `prompt_toolkit` dependency version to 0.45. * Added Travis CI to run the tests automatically. -1.1.1: +1.1.1 ====== -Bug Fixes: +Bug Fixes ---------- * Change dictonary comprehension used in mycnf reader to list comprehension to make it compatible with Python 2.6. - -1.1.0: +1.1.0 ====== -Features: +Features --------- * Fuzzy completion is now case-insensitive. (Thanks: [bjarnagin](https://github.com/bjarnagin)) * Added new-line (`\n`) to the list of special characters to use in prompt. (Thanks: [brewneaux](https://github.com/brewneaux)) -* Honor the `pager` setting in my.cnf files. (Thanks: [Irina Truong](http://github.com/j-bennet)) +* Honor the `pager` setting in my.cnf files. (Thanks: [Irina Truong](https://github.com/j-bennet)) -Bug Fixes: +Bug Fixes ---------- * Fix a crashing bug in completion engine for cross joins. * Make `` value consistent between tabular and vertical output. -Internal Changes: +Internal Changes ----------------- * Changed pymysql version to be greater than 0.6.6. * Upgrade `prompt_toolkit` version to 0.42. (Thanks: [Yasuhiro Matsumoto](https://github.com/mattn)) * Removed the explicit dependency on six. -2015/06/10: +2015/06/10 =========== -Features: +Features --------- * Customizable prompt. (Thanks [Steve Robbins](https://github.com/steverobbins)) * Make `\G` formatting to behave more like mysql. -Bug Fixes: +Bug Fixes ---------- * Formatting issue in \G for really long column values. - -2015/06/07: +2015/06/07 =========== -Features: +Features --------- * Upgrade `prompt_toolkit` to 0.38. This improves the performance of pasting long queries. @@ -946,18 +2149,17 @@ Features: * Add fuzzy completion for table names and column names. * Automatically reconnect when connection is lost to the database. -Bug Fixes: +Bug Fixes ---------- * Fix a bug with reconnect failure. * Fix the issue with `use` command not changing the prompt. * Fix the issue where `\\r` shortcut was not recognized. - 2015/05/24 ========== -Features: +Features --------- * Add support for connecting via socket. @@ -966,25 +2168,22 @@ Features: * Made the timing of sql statements human friendly. * Automatically prompt for a password if needed. -Bug Fixes: +Bug Fixes ---------- + * Fixed the installation issues with PyMySQL dependency on case-sensitive file systems. [Amjith Ramanujam]: https://blog.amjith.com [Artem Bezsmertnyi]: https://github.com/mrdeathless [BuonOmo]: https://github.com/BuonOmo -[Carlos Afonso]: https://github.com/afonsocarlos -[Casper Langemeijer]: https://github.com/langemeijer -[Daniel West]: http://github.com/danieljwest +[Daniel West]: https://github.com/danieljwest [Dick Marinus]: https://github.com/meeuw [François Pietka]: https://github.com/fpietka [Frederic Aoustin]: https://github.com/fraoustin [Georgy Frolov]: https://github.com/pasenor [Irina Truong]: https://github.com/j-bennet [Jonathan Slenders]: https://github.com/jonathanslenders -[Kacper Kwapisz]: https://github.com/KKKas [laixintao]: https://github.com/laixintao -[Lennart Weller]: https://github.com/lhw [Martijn Engler]: https://github.com/martijnengler [Matheus Rosa]: https://github.com/mdsrosa [Mikhail Borisov]: https://github.com/borman @@ -996,7 +2195,6 @@ Bug Fixes: [spacewander]: https://github.com/spacewander [Terseus]: https://github.com/Terseus [Thomas Roten]: https://github.com/tsroten -[William GARCIA]: https://github.com/willgarcia [xeron]: https://github.com/xeron [Zach DeCook]: https://zachdecook.com [Will Wang]: https://github.com/willww64 diff --git a/doc/key_bindings.rst b/doc/key_bindings.rst index e3ebcd9b..9673921b 100644 --- a/doc/key_bindings.rst +++ b/doc/key_bindings.rst @@ -6,6 +6,12 @@ Most key bindings are simply inherited from `prompt-toolkit \llm "Capital of India?" +-- Answer text from the model... +-- ```sql +-- SELECT ...; +-- ``` +-- Your prompt is prefilled with the SQL above. +``` + +You can now hit Enter to run, or edit the query first. + +--- + +## What Context Is Sent + +When you ask a plain question via `\llm "..."`, mycli: +- Sends your question. +- Adds your current database schema: table names with column types. +- Adds one sample row (if available) from each table. + +This helps the model propose SQL that fits your schema. Follow‑ups using `-c` continue the same conversation and do not re-send the DB context (see “Continue Conversation (-c)”). + +Note: Context is gathered from the current connection. If you are not connected, using contextual mode will fail — connect first. + +--- + +## Using `llm` Subcommands from mycli + +You can run any `llm` CLI subcommand by prefixing it with `\llm` inside mycli. Examples: + +- List models: + ```text + \llm models + ``` +- Set the default model: + ```text + \llm models default gpt-5 + ``` +- Set provider API key: + ```text + \llm keys set openai + ``` +- Install a plugin (e.g., local models via Ollama): + ```text + \llm install llm-ollama + ``` + After installing or uninstalling plugins, mycli will restart to pick up new commands. + +Tab completion works for `\llm` subcommands, and even for model IDs under `models default`. + +Aside: for using local models. + +--- + +## Ask Questions With DB Context (default) + +Ask your question in quotes. mycli sends database context and extracts a SQL block if present. + +```text +World> \llm "Most visited urls?" +``` + +Behavior: +- Response is printed in the output pane. +- If the response contains a ```sql fenced block, mycli extracts the SQL and pre-fills it at your prompt. + +--- + +## Continue Conversation (-c) + +Use `-c` to ask a follow‑up that continues the previous conversation with the model. This does not re-send the DB context; it relies on the ongoing thread. + +```text +World> \llm "Top 10 customers by spend" +-- model returns analysis and a ```sql block; SQL is prefilled +World> \llm -c "Now include each customer's email and order count" +``` + +Behavior: +- Continues the last conversation in the `llm` history. +- Database context is not re-sent on follow‑ups. +- If the response includes a ```sql block, the SQL is pre-filled at your prompt. + + +--- + +## Examples + +- List available models: + ```text + World> \llm models + ``` + +- Change default model: + ```text + World> \llm models default llama3 + ``` + +- Set API key (for providers that require it): + ```text + World> \llm keys set openai + ``` + +- Ask a question with context: + ```text + World> \llm "Capital of India?" + ``` + +- Use a local model (after installing a plugin such as `llm-ollama`): + ```text + World> \llm install llm-ollama + World> \llm models default llama3 + World> \llm "Top 10 customers by spend" + ``` + +See: for details. + +--- + +## Customize the Prompt Template + +mycli uses a saved `llm` template named `mycli-llm-template` for contextual questions. You can view or edit it: + +```text +World> \llm templates edit mycli-llm-template +``` + +Tip: After first use, mycli ensures this template exists. To just view it without editing, use: + +```text +World> \llm templates show mycli-llm-template +``` + +--- + +## Troubleshooting + +- No SQL pre-fill: Ensure the model’s response includes a ```sql fenced block. The built‑in prompt encourages this, but some models may omit it; try asking the model to include SQL in a ```sql block. +- Not connected to a database: Contextual questions require a live connection. Connect first. Follow‑ups with `-c` only help after a successful contextual call. +- Plugin changes not recognized: After `\llm install` or `\llm uninstall`, mycli restarts automatically to load new commands. +- Provider/API issues: Use `\llm keys list` and `\llm keys set ` to check credentials. Use `\llm models` to confirm available models. + +--- + +## Notes and Safety + +- Data sent: Contextual questions send schema (table/column names and types) and a single sample row per table. Review your data sensitivity policies before using remote models; prefer local models (such as ollama) if needed. +- Help: Running `\llm` with no arguments shows a short usage message. + +## Turning Off LLM Support + +To turn off LLM support even when the `llm` dependency is installed, set the `MYCLI_LLM_OFF` environment variable: +```bash +export MYCLI_LLM_OFF=1 +``` + +This may be desirable for faster startup times. + + +--- + +## Learn More + +- `llm` project docs: https://llm.datasette.io/ +- `llm` plugin directory: https://llm.datasette.io/en/stable/plugins/directory.html diff --git a/screenshots/main.gif b/doc/screenshots/main.gif similarity index 100% rename from screenshots/main.gif rename to doc/screenshots/main.gif diff --git a/screenshots/tables.png b/doc/screenshots/tables.png similarity index 100% rename from screenshots/tables.png rename to doc/screenshots/tables.png diff --git a/mycli/AUTHORS b/mycli/AUTHORS index d5a9ce08..08823bd2 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -1,3 +1,8 @@ +Project Lead: +------------- + + * Roland Walker + Core Developers: ---------------- @@ -7,6 +12,7 @@ Core Developers: * Darik Gamble * Dick Marinus * Amjith Ramanujam + * Scott Nemes Contributors: ------------- @@ -15,6 +21,7 @@ Contributors: * Abirami P * Adam Chainz * Aljosha Papsch + * Allrob * Andy Teijelo Pérez * Angelo Lupo * Artem Bezsmertnyi @@ -32,6 +39,7 @@ Contributors: * Daniel West * Daniël van Eeden * Fabrizio Gennari + * FatBoyXPC * François Pietka * Frederic Aoustin * Georgy Frolov @@ -59,13 +67,13 @@ Contributors: * Michał Górny * Mike Palandra * Mikhail Borisov + * Miodrag Tokić * Morgan Mitchell * mrdeathless * Nathan Huang * Nicolas Palumbo * Phil Cohen * QiaoHou Peng - * Roland Walker * Ryan Smith * Scrappy Soft * Seamile @@ -98,6 +106,15 @@ Contributors: * Houston Wong * Mohamed Rezk * Ryosuke Kazami + * Cornel Cruceru + * Sherlock Holo + * keltaklo + * 924060929 + * tmijieux + * Angelino Storm + * Abhay Kumar + * yurenchen000 + * Linuxdazhao Created by: diff --git a/mycli/SPONSORS b/mycli/SPONSORS index 81b0904c..e3c95945 100644 --- a/mycli/SPONSORS +++ b/mycli/SPONSORS @@ -29,3 +29,7 @@ Many thanks to the following Kickstarter backers. * Ted Pennings * Chris Anderton * Jonathan Slenders + +# Other Donors + +* OpenAI diff --git a/mycli/TIPS b/mycli/TIPS new file mode 100644 index 00000000..c7a65955 --- /dev/null +++ b/mycli/TIPS @@ -0,0 +1,291 @@ +### +### CLI arguments +### + +check your ~/.myclirc settings using the --checkup flag! + +list your aliased DSNs with the --list-dsn flag! + +log every query and result with the --logfile option! + +the --checkpoint option helps track successful queries in batch mode! + +the --format option helps set the output format in batch mode! + +the --throttle option helps slow down queries in batch mode! + +the --password-file option can be used with a FIFO to avoid saving creds to a file! + +the --character-set option sets the character set for a single session! + +the --unbuffered flag can save memory when in batch mode! + +--use-keyring=true lets you access the system keyring for passwords! + +--use-keyring=reset resets a password saved to the system keyring! + +the --myclirc option can change the config file location for a single session! + +the --execute option lets you execute a single line of SQL! + +the --auto-vertical-output flag lets you automatically switch to vertical output! + +the --show-warnings flag turns on warnings from the MySQL server! + +the --no-warn flag turns off warnings before running a destructive query! + +the --init-command option lets you execute initialization SQL before a session! + +the --login-path option lets you work with login-path files! + +--keepalive-ticks= sets keepalive pings for a single session! + +### +### commands +### + +interact with an LLM using the \llm command! + +copy a query to the clipboard using \clip at the end of the query! + +\dt lists tables; \dt describes
! + +edit a query in an external editor using \edit! + +edit a query in an external editor using \edit ! + +\f lists favorite queries; \f executes a favorite! + +\fs saves a favorite query! + +\fd deletes a saved favorite query! + +\l lists databases! + +\once appends the next result to ! + +\| sends the next result to a subprocess! + +\t toggles timing of commands! + +\r or "connect" reconnects to the server! + +\delimiter changes the SQL delimiter! + +\q, "quit", or "exit" exits from the prompt! + +\? or "help" for help! + +"help " for help on SQL keywords! + +\n or "nopager" to disable the pager! + +use "tee"/"notee" to write/stop-writing results to a output file! + +\W or "warnings" enables automatic warnings display! + +\w or "nowarnings" disables automatic warnings display! + +\P or "pager" sets the pager. Try "pager less"! + +\R or "prompt" changes the prompt format! + +\Tr or "redirectformat" changes the table format for redirects! + +\# or "rehash" refreshes autocompletions! + +\. or "source" executes queries from a file! + +\s or "status" requests status information from the server! + +use "system " to execute a shell command! + +\T or "tableformat" changes the interactive table format! + +\u or "use" changes to a new database! + +the "watch" command executes a query every N seconds! + +use \bug to file a bug on GitHub! + +### +### environment variables +### + +run "export VISUAL='code --wait'" in your shell to \edit queries using VS Code! + +set environment variable MYCLI_LLM_OFF to skip loading LLM libraries! + +set environment variable MYCLI_HISTFILE to relocate the history file! + +set environment variable MYSQL_PWD to set a default password! + +set environment variable MYSQL_HOST to set a default host! + +set environment variable MYSQL_TCP_PORT to set a default port! + +set environment variable MYSQL_USER to set a default username! + +set environment variable MYSQL_UNIX_SOCKET to set a default socket! + +set environment variable MYSQL_DSN to set a default DSN! + +### +### general +### + +display query output vertically using \G at the end of a query! + +run SQL scripts in batch mode using the standard input! + +### +### keystrokes +### + +edit a query in an external editor using keystrokes control-x + control-e! + +open a documentation browser using keystroke F1! + +toggle smart completion using keystroke F2! + +toggle multi-line mode using keystroke F3! + +toggle vi mode using keystroke F4! + +complete at cursor using the tab key! + +summon completion candidates using control-space! + +control-space works well with "min_completion_trigger" in ~/.myclirc! + +prettify a query using keystrokes control-x + p! + +un-prettify a query using keystrokes control-x + u! + +insert the current date using keystrokes control-o + d! + +insert the quoted current date using keystrokes control-o + control-d! + +insert the current datetime using keystrokes control-o + t! + +insert the quoted current date using keystrokes control-o + control-t! + +search query history using keystroke control-r! + +use keystroke control-g to cancel completion popups! + +use keystroke right-arrow to accept a full-line suggestion from your history! + +cancel history search using keystrokes Escape or control-g! + +uppercase a word using keystroke alt-u! + +lowercase a word using keystroke alt-l! + +collapse multiple spaces using keystroke alt-\! + +undo using keystroke control-_ or control-x + control-u! + +ditto the last argument of the previous command with keystroke alt-.! + +ditto the last argument of the previous command with keystroke alt-_! + +turn the current query into a comment with keystroke alt-#! + +jump forward to a character with keystroke control-]! + +jump backward to a character with keystroke alt-control-]! + +insert all completions with keystroke alt-*! + +in multi-line mode, keystroke alt-Enter dispatches the query! + +keystroke control-q + control-j inserts a newline without dispatching the query! + +### +### myclirc options +### + +set "less_chatty = True" in ~/.myclirc to turn off these tips! + +set a fancy table format like "table_format = psql_unicode" in ~/.myclirc! + +change the string for NULLs with "null_string = " in ~/.myclirc! + +choose a color theme with "syntax_style" in ~/.myclirc! + +design a prompt with the "prompt" option in ~/.myclirc! + +turn off multi-line prompt indentation with "prompt_continuation = ''" in ~/.myclirc! + +save passwords in the system keyring with "use_keyring" in ~/.myclirc! + +enable SHOW WARNINGS with "show warnings" in ~/.myclirc! + +turn off smart completions with "smart_completion" in ~/.myclirc! + +turn on multi-line mode with "multi_line" in ~/.myclirc! + +turn off destructive warnings with "destructive_warning" in ~/.myclirc! + +control destructive warnings with "destructive_keywords" in ~/.myclirc! + +move the history file locattion with "history_file" in ~/.myclirc! + +enable an audit log with "audit_log" in ~/.myclirc! + +disable timing of SQL statements with "timiing" in ~/.myclirc! + +disable display of SQL when running a favorite with "show_favorite_query" in ~/.myclirc! + +notify after a long query by setting "beep_after_seconds" in ~/.myclirc! + +control alignment with "numeric_alignment" in ~/.myclirc! + +control binary value display with "binary_display" in ~/.myclirc! + +set vi key bindings with "key_bindings" in ~/.myclirc! + +show more suggestions with "wider_completion_menu" in ~/.myclirc! + +use the host alias in the prompt with "login_path_as_host" in ~/.myclirc! + +auto-display wide results vertically with "auto_vertical_output" in ~/.myclirc! + +control keyword casing in completions using "keyword_casing" in ~/.myclirc! + +disable pager on startup using "enable_pager" in ~/.myclirc! + +choose a pager command with "pager" in ~/.myclirc! + +customize colors using the "[colors]" section in ~/.myclirc! + +customize LLM commands using the "[llm]" section in ~/.myclirc! + +customize history search using "control_r" in ~/.myclirc! + +edit favorite queries directly using the "[favorite_queries]" section in ~/.myclirc! + +set up initial commands using the "[init-commands]" section in ~/.myclirc! + +create DSN shortcuts using the "[alias_dsn]" section in ~/.myclirc! + +set up per-DSN initial commands using the "[alias_dsn.init-commands]" section in ~/.myclirc! + +set up connection defaults using the "[connection]" section in ~/.myclirc! + +use "min_completion_trigger" in ~/.myclirc to defer completions! + +colorize search previews with "highlight_preview" in ~/.myclirc! + +### +### redirection +### + +redirect query output to a shell command with "$| "! + +redirect query output to a file with "$> "! + +append query output to a file with "$>> "! + +run a command after shell redirects with "post_redirect_command" in ~/.myclirc! diff --git a/mycli/__init__.py b/mycli/__init__.py index b5476c14..699df6c0 100644 --- a/mycli/__init__.py +++ b/mycli/__init__.py @@ -1 +1,3 @@ -__version__ = '1.27.2' +import importlib.metadata + +__version__: str = importlib.metadata.version("mycli") diff --git a/mycli/app_state.py b/mycli/app_state.py new file mode 100644 index 00000000..0aaad28a --- /dev/null +++ b/mycli/app_state.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections import defaultdict +import re +from typing import TYPE_CHECKING, Any + +from configobj import ConfigObj + +from mycli.config import str_to_bool, strip_matching_quotes + +if TYPE_CHECKING: + from mycli.main import MyCli + + +def normalize_ssl_mode(config: ConfigObj) -> tuple[str | None, str | None]: + ssl_mode = config['main'].get('ssl_mode', None) or config['connection'].get('default_ssl_mode', None) + if ssl_mode not in ('auto', 'on', 'off', None): + return None, f'Invalid config option provided for ssl_mode ({ssl_mode}); ignoring.' + return ssl_mode, None + + +def ensure_my_cnf_sections(my_cnf: ConfigObj) -> None: + if not my_cnf.get('client'): + my_cnf['client'] = {} + if not my_cnf.get('mysqld'): + my_cnf['mysqld'] = {} + + +def configure_prompt_state( + mycli: MyCli, + config: ConfigObj, + prompt: str | None, + prompt_cnf: str | None, + toolbar_format: str | None, +) -> None: + mycli.prompt_format = prompt or prompt_cnf or config['main']['prompt'] or mycli.default_prompt + mycli.prompt_lines = 0 + mycli.multiline_continuation_char = config['main']['prompt_continuation'] + mycli.toolbar_format = toolbar_format or config['main']['toolbar'] + mycli.terminal_tab_title_format = config['main']['terminal_tab_title'] + mycli.terminal_window_title_format = config['main']['terminal_window_title'] + mycli.multiplex_window_title_format = config['main']['multiplex_window_title'] + mycli.multiplex_pane_title_format = config['main']['multiplex_pane_title'] + + +def destructive_keywords_from_config(config: ConfigObj) -> list[str]: + keywords = config['main'].get('destructive_keywords', 'DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE') + return [keyword for keyword in keywords.split(' ') if keyword] + + +def llm_prompt_truncation(config: ConfigObj) -> tuple[int, int]: + if 'llm' in config and re.match(r'^\d+$', config['llm'].get('prompt_field_truncate', '')): + field_truncate = int(config['llm'].get('prompt_field_truncate')) + else: + field_truncate = 0 + if 'llm' in config and re.match(r'^\d+$', config['llm'].get('prompt_section_truncate', '')): + section_truncate = int(config['llm'].get('prompt_section_truncate')) + else: + section_truncate = 0 + return field_truncate, section_truncate + + +class AppStateMixin: + defaults_suffix: str | None + login_path: str | None + + def read_my_cnf(self, cnf: ConfigObj, keys: list[str]) -> dict[str, Any]: + sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + 'user': 'default_user', + }, + } + + if self.login_path and self.login_path != 'client': + sections.append(self.login_path) + + if self.defaults_suffix: + sections.extend([sect + self.defaults_suffix for sect in sections]) + + configuration: dict[str, Any] = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if section not in sections or key not in cnf[section]: + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes(cnf[section][key]) + + return configuration + + def merge_ssl_with_cnf(self, ssl: dict[str, Any], cnf: dict[str, Any]) -> dict[str, Any]: + merged = {} + merged.update(ssl) + prefix = 'ssl-' + for key, value in cnf.items(): + if not key.startswith(prefix): + continue + if value is None: + continue + if key == 'ssl-verify-server-cert': + merged['check_hostname'] = str_to_bool(value) + else: + merged[key[len(prefix) :]] = value + + return merged diff --git a/mycli/cli_args.py b/mycli/cli_args.py new file mode 100644 index 00000000..bf95f59d --- /dev/null +++ b/mycli/cli_args.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +from dataclasses import dataclass +from io import TextIOWrapper +import os +import sys +from typing import Callable + +import click +import clickdc + +EMPTY_PASSWORD_FLAG_SENTINEL = -1 +DEFAULT_PROMPT = "\\t \\u@\\h:\\d> " + + +class IntOrStringClickParamType(click.ParamType): + name = 'text' # display as TEXT in helpdoc + + def convert(self, value, param, ctx): + if isinstance(value, int): + return value + elif isinstance(value, str): + return value + elif value is None: + return value + else: + self.fail('Not a valid password string', param, ctx) + + +INT_OR_STRING_CLICK_TYPE = IntOrStringClickParamType() + + +@dataclass(slots=True) +class CliArgs: + database: str | None = clickdc.argument( + type=str, + default=None, + nargs=1, + ) + host: str | None = clickdc.option( + '-h', + '--hostname', + 'host', + type=str, + envvar='MYSQL_HOST', + help='Host address of the database.', + ) + port: int | None = clickdc.option( + '-P', + type=int, + envvar='MYSQL_TCP_PORT', + help='Port number to use for connection. Honors $MYSQL_TCP_PORT.', + ) + user: str | None = clickdc.option( + '-u', + '--user', + '--username', + 'user', + type=str, + envvar='MYSQL_USER', + help='User name to connect to the database.', + ) + socket: str | None = clickdc.option( + '-S', + type=str, + envvar='MYSQL_UNIX_SOCKET', + help='The socket file to use for connection.', + ) + password: int | str | None = clickdc.option( + '-p', + '--pass', + '--password', + 'password', + type=INT_OR_STRING_CLICK_TYPE, + is_flag=False, + flag_value=EMPTY_PASSWORD_FLAG_SENTINEL, + help='Prompt for (or pass in cleartext) the password to connect to the database.', + ) + password_file: str | None = clickdc.option( + type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.', + ) + ssh_user: str | None = clickdc.option( + type=str, + help='User name to connect to ssh server.', + ) + ssh_host: str | None = clickdc.option( + type=str, + help='Host name to connect to ssh server.', + ) + ssh_port: int = clickdc.option( + type=int, + default=22, + help='Port to connect to ssh server.', + ) + ssh_password: str | None = clickdc.option( + type=str, + help='Password to connect to ssh server.', + ) + ssh_key_filename: str | None = clickdc.option( + type=str, + help='Private key filename (identify file) for the ssh connection.', + ) + ssh_config_path: str = clickdc.option( + type=str, + help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config', + ) + ssh_config_host: str | None = clickdc.option( + type=str, + help='Host to connect to ssh server reading from ssh configuration.', + ) + list_ssh_config: bool = clickdc.option( + is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).', + ) + ssh_warning_off: bool = clickdc.option( + is_flag=True, + help='Suppress the SSH deprecation notice.', + ) + ssl_mode: str = clickdc.option( + type=click.Choice(['auto', 'on', 'off']), + help='Set desired SSL behavior. auto=preferred if TCP/IP, on=required, off=off.', + ) + deprecated_ssl: bool | None = clickdc.option( + '--ssl/--no-ssl', + 'deprecated_ssl', + default=None, + clickdc=None, + help='Enable SSL for connection (automatically enabled with other flags).', + ) + ssl_ca: str | None = clickdc.option( + type=click.Path(exists=True), + help='CA file in PEM format.', + ) + ssl_capath: str | None = clickdc.option( + type=click.Path(exists=True, file_okay=False, dir_okay=True), + help='CA directory.', + ) + ssl_cert: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 cert in PEM format.', + ) + ssl_key: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 key in PEM format.', + ) + ssl_cipher: str | None = clickdc.option( + type=str, + help='SSL cipher to use.', + ) + tls_version: str | None = clickdc.option( + type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), + help='TLS protocol version for secure connection.', + ) + ssl_verify_server_cert: bool = clickdc.option( + is_flag=True, + help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), + ) + verbose: int = clickdc.option( + '-v', + count=True, + help='More verbose output and feedback. Can be given multiple times.', + ) + quiet: bool = clickdc.option( + '-q', + is_flag=True, + help='Less verbose output and feedback.', + ) + dbname: str | None = clickdc.option( + '-D', + '--database', + 'dbname', + type=str, + clickdc=None, + help='Database or DSN to use for the connection.', + ) + dsn: str = clickdc.option( + '-d', + type=str, + default='', + envvar='MYSQL_DSN', + help='DSN alias configured in the ~/.myclirc file, or a full DSN.', + ) + list_dsn: bool = clickdc.option( + is_flag=True, + help='Show list of DSN aliases configured in the [alias_dsn] section of ~/.myclirc.', + ) + prompt: str | None = clickdc.option( + '-R', + type=str, + help=f'Prompt format (Default: "{DEFAULT_PROMPT}").', + ) + toolbar: str | None = clickdc.option( + type=str, + help='Toolbar format.', + ) + logfile: TextIOWrapper | None = clickdc.option( + '-l', + type=click.File(mode='a', encoding='utf-8'), + help='Log every query and its results to a file.', + ) + checkpoint: TextIOWrapper | None = clickdc.option( + type=click.File(mode='a', encoding='utf-8'), + help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + ) + resume: bool = clickdc.option( + '--resume', + is_flag=True, + help='In batch mode, resume after replaying statements in the --checkpoint file.', + ) + defaults_group_suffix: str | None = clickdc.option( + type=str, + help='Read MySQL config groups with the specified suffix.', + ) + defaults_file: str | None = clickdc.option( + type=click.Path(), + help='Only read MySQL options from the given file.', + ) + myclirc: str = clickdc.option( + type=click.Path(), + default='~/.myclirc', + help='Location of myclirc file.', + ) + auto_vertical_output: bool = clickdc.option( + is_flag=True, + help='Automatically switch to vertical output mode if the result is wider than the terminal width.', + ) + show_warnings: bool | None = clickdc.option( + '--show-warnings/--no-show-warnings', + is_flag=True, + default=None, + clickdc=None, + help='Automatically show warnings after executing a SQL statement.', + ) + table: bool = clickdc.option( + '-t', + is_flag=True, + help='Shorthand for --format=table.', + ) + csv: bool = clickdc.option( + is_flag=True, + help='Shorthand for --format=csv.', + ) + warn: bool | None = clickdc.option( + '--warn/--no-warn', + default=None, + clickdc=None, + help='Warn before running a destructive query.', + ) + local_infile: bool | None = clickdc.option( + type=bool, + is_flag=False, + default=None, + help='Enable/disable LOAD DATA LOCAL INFILE.', + ) + login_path: str | None = clickdc.option( + '-g', + type=str, + help='Read this path from the login file.', + ) + execute: str | None = clickdc.option( + '-e', + type=str, + help='Execute command and quit.', + ) + init_command: str | None = clickdc.option( + type=str, + help='SQL statement to execute after connecting.', + ) + unbuffered: bool | None = clickdc.option( + is_flag=True, + help='Instead of copying every row of data into a buffer, fetch rows as needed, to save memory.', + ) + character_set: str | None = clickdc.option( + '--charset', + '--character-set', + 'character_set', + type=str, + help='Character set for MySQL session.', + ) + batch: str | None = clickdc.option( + type=str, + help='SQL script to execute in batch mode.', + ) + noninteractive: bool = clickdc.option( + is_flag=True, + help="Don't prompt during batch input. Recommended.", + ) + format: str | None = clickdc.option( + type=click.Choice(['default', 'csv', 'tsv', 'table']), + help='Format for batch or --execute output.', + ) + throttle: float = clickdc.option( + type=float, + default=0.0, + help='Pause in seconds between queries in batch mode.', + ) + progress: bool = clickdc.option( + is_flag=True, + help='Show progress on the standard error with --batch.', + ) + use_keyring: str | None = clickdc.option( + type=click.Choice(['true', 'false', 'reset']), + default=None, + help='Store and retrieve passwords from the system keyring: true/false/reset.', + ) + keepalive_ticks: int | None = clickdc.option( + type=int, + help='Send regular keepalive pings to the connection, roughly every seconds.', + ) + checkup: bool = clickdc.option( + is_flag=True, + help='Run a checkup on your configuration.', + ) + + +def get_password_from_file(password_file: str | None) -> str | None: + if not password_file: + return None + try: + with open(password_file) as fp: + return fp.readline().removesuffix('\n') + except FileNotFoundError: + click.secho(f"Password file '{password_file}' not found", err=True, fg='red') + sys.exit(1) + except PermissionError: + click.secho(f"Permission denied reading password file '{password_file}'", err=True, fg='red') + sys.exit(1) + except IsADirectoryError: + click.secho(f"Path '{password_file}' is a directory, not a file", err=True, fg='red') + sys.exit(1) + except Exception as e: + click.secho(f"Error reading password file '{password_file}': {str(e)}", err=True, fg='red') + sys.exit(1) + + +def preprocess_cli_args( + cli_args: CliArgs, + is_valid_connection_scheme: Callable[[str], tuple[bool, str | None]], +) -> int: + if cli_args.database is None and isinstance(cli_args.password, str) and '://' in cli_args.password: + is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.password) + if not is_valid_scheme: + click.secho(f'Error: Unknown connection scheme provided for DSN URI ({scheme}://)', err=True, fg='red') + sys.exit(1) + cli_args.database = cli_args.password + cli_args.password = EMPTY_PASSWORD_FLAG_SENTINEL + + if cli_args.password is None and cli_args.password_file: + password_from_file = get_password_from_file(cli_args.password_file) + if password_from_file is not None: + cli_args.password = password_from_file + + if cli_args.password is None and os.environ.get('MYSQL_PWD') is not None: + cli_args.password = os.environ.get('MYSQL_PWD') + + if cli_args.resume and not cli_args.checkpoint: + click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') + sys.exit(1) + + if cli_args.resume and not cli_args.batch: + click.secho('Error: --resume requires a --batch file.', err=True, fg='red') + sys.exit(1) + + if cli_args.verbose and cli_args.quiet: + click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') + sys.exit(1) + elif cli_args.verbose: + return int(cli_args.verbose) + elif cli_args.quiet: + return -1 + return 0 diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index 81353b63..edbc64cb 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -1,55 +1,54 @@ -from prompt_toolkit.enums import DEFAULT_BUFFER -from prompt_toolkit.filters import Condition from prompt_toolkit.application import get_app -from .packages import special +from prompt_toolkit.enums import DEFAULT_BUFFER +from prompt_toolkit.filters import Condition, Filter + +from mycli.packages.special import iocommands +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS -def cli_is_multiline(mycli): +def cli_is_multiline(mycli) -> Filter: @Condition def cond(): - doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document - if not mycli.multi_line: return False else: + doc = get_app().layout.get_buffer_by_name(DEFAULT_BUFFER).document return not _multiline_exception(doc.text) + return cond -def _multiline_exception(text): +def _multiline_exception(text: str) -> bool: orig = text text = text.strip() + first_word = text.split()[0] if text else '' # Multi-statement favorite query is a special case. Because there will # be a semicolon separating statements, we can't consider semicolon an # EOL. Let's consider an empty line an EOL instead. - if text.startswith('\\fs'): - return orig.endswith('\n') + if first_word.startswith("\\fs"): + return orig.endswith("\n") return ( # Special Command - text.startswith('\\') or - - # Delimiter declaration - text.lower().startswith('delimiter') or - - # Ended with the current delimiter (usually a semi-column) - text.endswith(special.get_current_delimiter()) or - - text.endswith('\\g') or - text.endswith('\\G') or - text.endswith(r'\e') or - text.endswith(r'\clip') or - - # Exit doesn't need semi-column` - (text == 'exit') or - - # Quit doesn't need semi-column - (text == 'quit') or - - # To all teh vim fans out there - (text == ':q') or - + first_word.startswith("\\") + or text.endswith(( + # Ended with the current delimiter (usually a semi-column) + iocommands.get_current_delimiter(), + # or ended with certain commands + "\\g", + "\\G", + r"\e", + r"\edit", + r"\clip", + )) + or + # non-backslashed special commands such as "exit" or "help" don't need semicolon + first_word in SPECIAL_COMMANDS + or + # uppercase variants accepted + first_word.lower() in SPECIAL_COMMANDS + or # just a plain enter without any text - (text == '') + (first_word == "") ) diff --git a/mycli/clistyle.py b/mycli/clistyle.py index b0ac9922..c86694e8 100644 --- a/mycli/clistyle.py +++ b/mycli/clistyle.py @@ -1,81 +1,100 @@ import logging +from typing import cast -import pygments.styles -from pygments.token import string_to_tokentype, Token +from prompt_toolkit.styles import Style, merge_styles +from prompt_toolkit.styles.pygments import style_from_pygments_cls +from prompt_toolkit.styles.style import _MergedStyle from pygments.style import Style as PygmentsStyle +import pygments.styles +from pygments.token import Token, string_to_tokentype from pygments.util import ClassNotFound -from prompt_toolkit.styles.pygments import style_from_pygments_cls -from prompt_toolkit.styles import merge_styles, Style logger = logging.getLogger(__name__) # map Pygments tokens (ptk 1.0) to class names (ptk 2.0). -TOKEN_TO_PROMPT_STYLE = { - Token.Menu.Completions.Completion.Current: 'completion-menu.completion.current', - Token.Menu.Completions.Completion: 'completion-menu.completion', - Token.Menu.Completions.Meta.Current: 'completion-menu.meta.completion.current', - Token.Menu.Completions.Meta: 'completion-menu.meta.completion', - Token.Menu.Completions.MultiColumnMeta: 'completion-menu.multi-column-meta', - Token.Menu.Completions.ProgressButton: 'scrollbar.arrow', # best guess - Token.Menu.Completions.ProgressBar: 'scrollbar', # best guess - Token.SelectedText: 'selected', - Token.SearchMatch: 'search', - Token.SearchMatch.Current: 'search.current', - Token.Toolbar: 'bottom-toolbar', - Token.Toolbar.Off: 'bottom-toolbar.off', - Token.Toolbar.On: 'bottom-toolbar.on', - Token.Toolbar.Search: 'search-toolbar', - Token.Toolbar.Search.Text: 'search-toolbar.text', - Token.Toolbar.System: 'system-toolbar', - Token.Toolbar.Arg: 'arg-toolbar', - Token.Toolbar.Arg.Text: 'arg-toolbar.text', - Token.Toolbar.Transaction.Valid: 'bottom-toolbar.transaction.valid', - Token.Toolbar.Transaction.Failed: 'bottom-toolbar.transaction.failed', - Token.Output.Header: 'output.header', - Token.Output.OddRow: 'output.odd-row', - Token.Output.EvenRow: 'output.even-row', - Token.Output.Null: 'output.null', - Token.Prompt: 'prompt', - Token.Continuation: 'continuation', +TOKEN_TO_PROMPT_STYLE: dict[Token, str] = { + Token.Menu.Completions.Completion.Current: "completion-menu.completion.current", + Token.Menu.Completions.Completion: "completion-menu.completion", + Token.Menu.Completions.Meta.Current: "completion-menu.meta.completion.current", + Token.Menu.Completions.Meta: "completion-menu.meta.completion", + Token.Menu.Completions.MultiColumnMeta: "completion-menu.multi-column-meta", + Token.Menu.Completions.ProgressButton: "scrollbar.arrow", # best guess + Token.Menu.Completions.ProgressBar: "scrollbar", # best guess + Token.SelectedText: "selected", + Token.SearchMatch: "search", + Token.SearchMatch.Current: "search.current", + Token.MatchingBracket.Cursor: "matching-bracket.cursor", + Token.MatchingBracket.Other: "matching-bracket.other", + Token.Toolbar: "bottom-toolbar", + Token.Toolbar.Off: "bottom-toolbar.off", + Token.Toolbar.On: "bottom-toolbar.on", + Token.Toolbar.Search: "search-toolbar", + Token.Toolbar.Search.Text: "search-toolbar.text", + Token.Toolbar.System: "system-toolbar", + Token.Toolbar.Arg: "arg-toolbar", + Token.Toolbar.Arg.Text: "arg-toolbar.text", + Token.Toolbar.Transaction.Valid: "bottom-toolbar.transaction.valid", + Token.Toolbar.Transaction.Failed: "bottom-toolbar.transaction.failed", + Token.Output.TableSeparator: "output.table-separator", + Token.Output.Header: "output.header", + Token.Output.OddRow: "output.odd-row", + Token.Output.EvenRow: "output.even-row", + Token.Output.Null: "output.null", + Token.Output.Status: "output.status", + Token.Output.Status.WarningCount: "output.status.warning-count", + Token.Output.Timing: "output.timing", + Token.Warnings.TableSeparator: "warnings.table-separator", + Token.Warnings.Header: "warnings.header", + Token.Warnings.OddRow: "warnings.odd-row", + Token.Warnings.EvenRow: "warnings.even-row", + Token.Warnings.Null: "warnings.null", + Token.Warnings.Status: "warnings.status", + Token.Warnings.Status.WarningCount: "warnings.status.warning-count", + Token.Warnings.Timing: "warnings.timing", + Token.Prompt: "prompt", + Token.Continuation: "continuation", } # reverse dict for cli_helpers, because they still expect Pygments tokens. -PROMPT_STYLE_TO_TOKEN = { - v: k for k, v in TOKEN_TO_PROMPT_STYLE.items() -} +PROMPT_STYLE_TO_TOKEN: dict[str, Token] = {v: k for k, v in TOKEN_TO_PROMPT_STYLE.items()} # all tokens that the Pygments MySQL lexer can produce -OVERRIDE_STYLE_TO_TOKEN = { - 'sql.comment': Token.Comment, - 'sql.comment.multi-line': Token.Comment.Multiline, - 'sql.comment.single-line': Token.Comment.Single, - 'sql.comment.optimizer-hint': Token.Comment.Special, - 'sql.escape': Token.Error, - 'sql.keyword': Token.Keyword, - 'sql.datatype': Token.Keyword.Type, - 'sql.literal': Token.Literal, - 'sql.literal.date': Token.Literal.Date, - 'sql.symbol': Token.Name, - 'sql.quoted-schema-object': Token.Name.Quoted, - 'sql.quoted-schema-object.escape': Token.Name.Quoted.Escape, - 'sql.constant': Token.Name.Constant, - 'sql.function': Token.Name.Function, - 'sql.variable': Token.Name.Variable, - 'sql.number': Token.Number, - 'sql.number.binary': Token.Number.Bin, - 'sql.number.float': Token.Number.Float, - 'sql.number.hex': Token.Number.Hex, - 'sql.number.integer': Token.Number.Integer, - 'sql.operator': Token.Operator, - 'sql.punctuation': Token.Punctuation, - 'sql.string': Token.String, - 'sql.string.double-quouted': Token.String.Double, - 'sql.string.escape': Token.String.Escape, - 'sql.string.single-quoted': Token.String.Single, - 'sql.whitespace': Token.Text, +OVERRIDE_STYLE_TO_TOKEN: dict[str, Token] = { + "sql.comment": Token.Comment, + "sql.comment.multi-line": Token.Comment.Multiline, + "sql.comment.single-line": Token.Comment.Single, + "sql.comment.optimizer-hint": Token.Comment.Special, + "sql.escape": Token.Error, + "sql.keyword": Token.Keyword, + "sql.datatype": Token.Keyword.Type, + "sql.literal": Token.Literal, + "sql.literal.date": Token.Literal.Date, + "sql.symbol": Token.Name, + "sql.quoted-schema-object": Token.Name.Quoted, + "sql.quoted-schema-object.escape": Token.Name.Quoted.Escape, + "sql.constant": Token.Name.Constant, + "sql.function": Token.Name.Function, + "sql.variable": Token.Name.Variable, + "sql.number": Token.Number, + "sql.number.binary": Token.Number.Bin, + "sql.number.float": Token.Number.Float, + "sql.number.hex": Token.Number.Hex, + "sql.number.integer": Token.Number.Integer, + "sql.operator": Token.Operator, + "sql.punctuation": Token.Punctuation, + "sql.string": Token.String, + "sql.string.double-quouted": Token.String.Double, + "sql.string.escape": Token.String.Escape, + "sql.string.single-quoted": Token.String.Single, + "sql.whitespace": Token.Text, } -def parse_pygments_style(token_name, style_object, style_dict): + +def parse_pygments_style( + token_name: str, + style_object: type[PygmentsStyle] | PygmentsStyle | dict[object, str] | str, + style_dict: dict[str, str], +) -> tuple[Token, str]: """Parse token type and style string. :param token_name: str name of Pygments token. Example: "Token.String" @@ -84,66 +103,105 @@ def parse_pygments_style(token_name, style_object, style_dict): """ token_type = string_to_tokentype(token_name) - try: + if isinstance(style_object, type) and issubclass(style_object, PygmentsStyle): + # When a Pygments Style class is passed, use its "styles" mapping. + other_token_type = string_to_tokentype(style_dict[token_name]) + style_class = cast(type[PygmentsStyle], style_object) + return token_type, style_class.styles[other_token_type] + elif isinstance(style_object, PygmentsStyle): other_token_type = string_to_tokentype(style_dict[token_name]) return token_type, style_object.styles[other_token_type] - except AttributeError as err: + else: return token_type, style_dict[token_name] -def style_factory(name, cli_style): +def is_valid_pygments(name: str) -> bool: + try: + + class TestStyle(PygmentsStyle): + default_style = '' + styles = {Token.Default: name} + + return True + except AssertionError: + # can't emit error because some styles are valid pygments and not valid ptoolkit + return False + + +def is_valid_ptoolkit(name: str) -> bool: + try: + _s = Style([("default", name)]) + return True + except ValueError: + # can't emit error because some styles are valid pygments and not valid ptoolkit + return False + + +def style_factory_ptoolkit(name: str, cli_style: dict[str, str]) -> _MergedStyle: try: - style = pygments.styles.get_style_by_name(name) + style: PygmentsStyle = pygments.styles.get_style_by_name(name) except ClassNotFound: - style = pygments.styles.get_style_by_name('native') + style = pygments.styles.get_style_by_name("native") - prompt_styles = [] + prompt_styles: list[tuple[str, str]] = [] # prompt-toolkit used pygments tokens for styling before, switched to style # names in 2.0. Convert old token types to new style names, for backwards compatibility. for token in cli_style: - if token.startswith('Token.'): + if token.startswith("Token."): # treat as pygments token (1.0) - token_type, style_value = parse_pygments_style( - token, style, cli_style) + token_type, style_value = parse_pygments_style(token, style, cli_style) if token_type in TOKEN_TO_PROMPT_STYLE: prompt_style = TOKEN_TO_PROMPT_STYLE[token_type] - prompt_styles.append((prompt_style, style_value)) + if is_valid_ptoolkit(style_value): + prompt_styles.append((prompt_style, style_value)) else: # we don't want to support tokens anymore - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) else: # treat as prompt style name (2.0). See default style names here: # https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/styles/defaults.py - prompt_styles.append((token, cli_style[token])) + if is_valid_ptoolkit(cli_style[token]): + prompt_styles.append((token, cli_style[token])) - override_style = Style([('bottom-toolbar', 'noreverse')]) - return merge_styles([ - style_from_pygments_cls(style), - override_style, - Style(prompt_styles) - ]) + override_style: Style = Style([("bottom-toolbar", "noreverse")]) + return merge_styles([style_from_pygments_cls(style), override_style, Style(prompt_styles)]) -def style_factory_output(name, cli_style): +def style_factory_helpers( + name: str, + cli_style: dict[str, str], + warnings: bool = False, +) -> PygmentsStyle: try: - style = pygments.styles.get_style_by_name(name).styles + style: dict[PygmentsStyle | str, str] = pygments.styles.get_style_by_name(name).styles except ClassNotFound: - style = pygments.styles.get_style_by_name('native').styles + style = pygments.styles.get_style_by_name("native").styles for token in cli_style: - if token.startswith('Token.'): - token_type, style_value = parse_pygments_style( - token, style, cli_style) - style.update({token_type: style_value}) + if token.startswith("Token."): + token_type, style_value = parse_pygments_style(token, style, cli_style) + if is_valid_pygments(style_value): + style.update({token_type: style_value}) elif token in PROMPT_STYLE_TO_TOKEN: token_type = PROMPT_STYLE_TO_TOKEN[token] - style.update({token_type: cli_style[token]}) + if is_valid_pygments(cli_style[token]): + style.update({token_type: cli_style[token]}) elif token in OVERRIDE_STYLE_TO_TOKEN: token_type = OVERRIDE_STYLE_TO_TOKEN[token] - style.update({token_type: cli_style[token]}) + if is_valid_pygments(cli_style[token]): + style.update({token_type: cli_style[token]}) else: # TODO: cli helpers will have to switch to ptk.Style - logger.error('Unhandled style / class name: %s', token) + logger.error("Unhandled style / class name: %s", token) + + if warnings: + for warning_token in list(style.keys()): + if 'Warnings' not in str(warning_token): + continue + warning_str = str(warning_token) + output_str = warning_str.replace('Warnings', 'Output') + output_token = string_to_tokentype(output_str) + style[output_token] = style[warning_token] class OutputStyle(PygmentsStyle): default_style = "" diff --git a/mycli/clitoolbar.py b/mycli/clitoolbar.py index 52b6ee45..80700415 100644 --- a/mycli/clitoolbar.py +++ b/mycli/clitoolbar.py @@ -1,57 +1,103 @@ -from prompt_toolkit.key_binding.vi_state import InputMode +from typing import Callable + from prompt_toolkit.application import get_app from prompt_toolkit.enums import EditingMode -from .packages import special +from prompt_toolkit.formatted_text import AnyFormattedText, to_formatted_text +from prompt_toolkit.key_binding.vi_state import InputMode +from mycli.packages import special -def create_toolbar_tokens_func(mycli, show_fish_help): + +def create_toolbar_tokens_func( + mycli, + show_initial_toolbar_help: Callable[[], bool], + format_string: str | None, + get_custom_toolbar: Callable[[str], AnyFormattedText], +) -> Callable[[], list[tuple[str, str]]]: """Return a function that generates the toolbar tokens.""" - def get_toolbar_tokens(): - result = [('class:bottom-toolbar', ' ')] - if mycli.multi_line: - delimiter = special.get_current_delimiter() - result.append( - ( - 'class:bottom-toolbar', - ' ({} [{}] will end the line) '.format( - 'Semi-colon' if delimiter == ';' else 'Delimiter', delimiter) - )) + def get_toolbar_tokens() -> list[tuple[str, str]]: + divider = ('class:bottom-toolbar', ' │ ') + + result = [("class:bottom-toolbar", "[Tab] Complete")] + dynamic = [] + + result.append(divider) + result.append(("class:bottom-toolbar", "[F1] Help")) + + if mycli.completer.smart_completion: + result.append(divider) + result.append(("class:bottom-toolbar", "[F2] Smart-complete:")) + result.append(("class:bottom-toolbar.on", "ON ")) + else: + result.append(divider) + result.append(("class:bottom-toolbar", "[F2] Smart-complete:")) + result.append(("class:bottom-toolbar.off", "OFF")) if mycli.multi_line: - result.append(('class:bottom-toolbar.on', '[F3] Multiline: ON ')) + result.append(divider) + result.append(("class:bottom-toolbar", "[F3] Multiline:")) + result.append(("class:bottom-toolbar.on", "ON ")) else: - result.append(('class:bottom-toolbar.off', - '[F3] Multiline: OFF ')) - if mycli.prompt_app.editing_mode == EditingMode.VI: - result.append(( - 'class:bottom-toolbar.on', - 'Vi-mode ({})'.format(_get_vi_mode()) - )) + result.append(divider) + result.append(("class:bottom-toolbar", "[F3] Multiline:")) + result.append(("class:bottom-toolbar.off", "OFF")) + + if mycli.prompt_session.editing_mode == EditingMode.VI: + result.append(divider) + result.append(("class:bottom-toolbar", "Vi:")) + result.append(("class:bottom-toolbar.on", _get_vi_mode())) if mycli.toolbar_error_message: - result.append( - ('class:bottom-toolbar', ' ' + mycli.toolbar_error_message)) + dynamic.append(divider) + dynamic.append(("class:bottom-toolbar.transaction.failed", mycli.toolbar_error_message)) mycli.toolbar_error_message = None - if show_fish_help(): - result.append( - ('class:bottom-toolbar', ' Right-arrow to complete suggestion')) + if mycli.multi_line: + delimiter = special.get_current_delimiter() + if delimiter != ';' or show_initial_toolbar_help(): + dynamic.append(divider) + dynamic.append(('class:bottom-toolbar', '"')) + dynamic.append(('class:bottom-toolbar.on', delimiter)) + dynamic.append(('class:bottom-toolbar', '" ends a statement')) + + if show_initial_toolbar_help(): + dynamic.append(divider) + dynamic.append(("class:bottom-toolbar", "right-arrow accepts full-line suggestion")) if mycli.completion_refresher.is_refreshing(): - result.append( - ('class:bottom-toolbar', ' Refreshing completions...')) + dynamic.append(divider) + dynamic.append(("class:bottom-toolbar", "Refreshing completions…")) + + schema_prefetcher = getattr(mycli, 'schema_prefetcher', None) + if schema_prefetcher is not None and schema_prefetcher.is_prefetching(): + dynamic.append(divider) + dynamic.append(("class:bottom-toolbar", "Prefetching schemas…")) + if format_string and format_string != r'\B': + if format_string.startswith(r'\B'): + amended_format = format_string[2:] + result.extend(dynamic) + dynamic = [] + result.append(('class:bottom-toolbar', '\n')) + else: + amended_format = format_string + result = [] + formatted = to_formatted_text(get_custom_toolbar(amended_format), style='class:bottom-toolbar') + result.extend([*formatted]) # coerce to list for mypy + + result.extend(dynamic) return result + return get_toolbar_tokens -def _get_vi_mode(): +def _get_vi_mode() -> str: """Get the current vi mode for display.""" return { - InputMode.INSERT: 'I', - InputMode.NAVIGATION: 'N', - InputMode.REPLACE: 'R', - InputMode.REPLACE_SINGLE: 'R', - InputMode.INSERT_MULTIPLE: 'M', + InputMode.INSERT: "I", + InputMode.NAVIGATION: "N", + InputMode.REPLACE: "R", + InputMode.REPLACE_SINGLE: "R", + InputMode.INSERT_MULTIPLE: "M", }[get_app().vi_state.input_mode] diff --git a/mycli/compat.py b/mycli/compat.py index 2ebfe07f..bca14261 100644 --- a/mycli/compat.py +++ b/mycli/compat.py @@ -2,5 +2,4 @@ import sys - -WIN = sys.platform in ('win32', 'cygwin') +WIN: bool = sys.platform in ("win32", "cygwin") diff --git a/mycli/completion_refresher.py b/mycli/completion_refresher.py index 5d5f40fc..81e74060 100644 --- a/mycli/completion_refresher.py +++ b/mycli/completion_refresher.py @@ -1,19 +1,27 @@ import threading -from .packages.special.main import COMMANDS -from collections import OrderedDict +from typing import Callable -from .sqlcompleter import SQLCompleter -from .sqlexecute import SQLExecute, ServerSpecies +import pymysql -class CompletionRefresher(object): +from mycli.packages.special.main import COMMANDS +from mycli.packages.sqlresult import SQLResult +from mycli.sqlcompleter import SQLCompleter +from mycli.sqlexecute import ServerSpecies, SQLExecute - refreshers = OrderedDict() - def __init__(self): - self._completer_thread = None +class CompletionRefresher: + refreshers: dict = {} + + def __init__(self) -> None: + self._completer_thread: threading.Thread | None = None self._restart_refresh = threading.Event() - def refresh(self, executor, callbacks, completer_options=None): + def refresh( + self, + executor: SQLExecute, + callbacks: Callable | list[Callable], + completer_options: dict | None = None, + ) -> list[SQLResult]: """Creates a SQLCompleter object and populates it with the relevant completion suggestions in a background thread. @@ -30,29 +38,47 @@ def refresh(self, executor, callbacks, completer_options=None): if self.is_refreshing(): self._restart_refresh.set() - return [(None, None, None, 'Auto-completion refresh restarted.')] + return [SQLResult(status="Auto-completion refresh restarted.")] else: self._completer_thread = threading.Thread( - target=self._bg_refresh, - args=(executor, callbacks, completer_options), - name='completion_refresh') + target=self._bg_refresh, args=(executor, callbacks, completer_options), name="completion_refresh" + ) self._completer_thread.daemon = True self._completer_thread.start() - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + return [SQLResult(status="Auto-completion refresh started in the background.")] - def is_refreshing(self): - return self._completer_thread and self._completer_thread.is_alive() + def is_refreshing(self) -> bool: + return bool(self._completer_thread and self._completer_thread.is_alive()) - def _bg_refresh(self, sqlexecute, callbacks, completer_options): + def _bg_refresh( + self, + sqlexecute: SQLExecute, + callbacks: Callable | list[Callable], + completer_options: dict, + ) -> None: completer = SQLCompleter(**completer_options) - # Create a new pgexecute method to populate the completions. + # Create a new sqlexecute method to populate the completions. e = sqlexecute - executor = SQLExecute(e.dbname, e.user, e.password, e.host, e.port, - e.socket, e.charset, e.local_infile, e.ssl, - e.ssh_user, e.ssh_host, e.ssh_port, - e.ssh_password, e.ssh_key_filename) + try: + executor = SQLExecute( + e.dbname, + e.user, + e.password, + e.host, + e.port, + e.socket, + e.character_set, + e.local_infile, + e.ssl, + e.ssh_user, + e.ssh_host, + e.ssh_port, + e.ssh_password, + e.ssh_key_filename, + ) + except pymysql.err.OperationalError: + return # If callbacks is a single function then push it into a list. if callable(callbacks): @@ -76,55 +102,95 @@ def _bg_refresh(self, sqlexecute, callbacks, completer_options): for callback in callbacks: callback(completer) -def refresher(name, refreshers=CompletionRefresher.refreshers): + executor.close() + + +def refresher(name: str, refreshers: dict = CompletionRefresher.refreshers) -> Callable: """Decorator to add the decorated function to the dictionary of refreshers. Any function decorated with a @refresher will be executed as part of the completion refresh routine.""" + def wrapper(wrapped): refreshers[name] = wrapped return wrapped + return wrapper -@refresher('databases') -def refresh_databases(completer, executor): + +@refresher("databases") +def refresh_databases(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_database_names(executor.databases()) -@refresher('schemata') -def refresh_schemata(completer, executor): + +@refresher("schemata") +def refresh_schemata(completer: SQLCompleter, executor: SQLExecute) -> None: # schemata - In MySQL Schema is the same as database. But for mycli # schemata will be the name of the current database. completer.extend_schemata(executor.dbname) completer.set_dbname(executor.dbname) -@refresher('tables') -def refresh_tables(completer, executor): - completer.extend_relations(executor.tables(), kind='tables') - completer.extend_columns(executor.table_columns(), kind='tables') -@refresher('users') -def refresh_users(completer, executor): +@refresher("tables") +def refresh_tables(completer: SQLCompleter, executor: SQLExecute) -> None: + table_columns_dbresult = list(executor.table_columns()) + completer.extend_relations(table_columns_dbresult, kind="tables") + completer.extend_columns(table_columns_dbresult, kind="tables") + + +@refresher("foreign_keys") +def refresh_foreign_keys(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_foreign_keys(executor.foreign_keys()) + + +@refresher("enum_values") +def refresh_enum_values(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_enum_values(executor.enum_values()) + + +@refresher("users") +def refresh_users(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_users(executor.users()) + # @refresher('views') -# def refresh_views(completer, executor): +# def refresh_views(completer: SQLCompleter, executor: SQLExecute) -> None: # completer.extend_relations(executor.views(), kind='views') # completer.extend_columns(executor.view_columns(), kind='views') -@refresher('functions') -def refresh_functions(completer, executor): + +@refresher("functions") +def refresh_functions(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_functions(executor.functions()) - if executor.server_info.species == ServerSpecies.TiDB: + if executor.server_info and executor.server_info.species == ServerSpecies.TiDB: completer.extend_functions(completer.tidb_functions, builtin=True) -@refresher('special_commands') -def refresh_special(completer, executor): - completer.extend_special_commands(COMMANDS.keys()) -@refresher('show_commands') -def refresh_show_commands(completer, executor): +@refresher("procedures") +def refresh_procedures(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_procedures(executor.procedures()) + + +@refresher("character_sets") +def refresh_character_sets(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_character_sets(executor.character_sets()) + + +@refresher("collations") +def refresh_collations(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_collations(executor.collations()) + + +@refresher("special_commands") +def refresh_special(completer: SQLCompleter, executor: SQLExecute) -> None: + completer.extend_special_commands(list(COMMANDS.keys())) + + +@refresher("show_commands") +def refresh_show_commands(completer: SQLCompleter, executor: SQLExecute) -> None: completer.extend_show_items(executor.show_candidates()) -@refresher('keywords') -def refresh_keywords(completer, executor): - if executor.server_info.species == ServerSpecies.TiDB: + +@refresher("keywords") +def refresh_keywords(completer: SQLCompleter, executor: SQLExecute) -> None: + if executor.server_info and executor.server_info.species == ServerSpecies.TiDB: completer.extend_keywords(completer.tidb_keywords, replace=True) diff --git a/mycli/config.py b/mycli/config.py index 5d711093..a79b1021 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,40 +1,29 @@ from copy import copy +from importlib import resources from io import BytesIO, TextIOWrapper import logging import os from os.path import exists import struct import sys -from typing import Union, IO +from typing import IO, BinaryIO, Literal from configobj import ConfigObj, ConfigObjError -import pyaes - -try: - import importlib.resources as resources -except ImportError: - # Python < 3.7 - import importlib_resources as resources - -try: - basestring -except NameError: - basestring = str - +from Cryptodome.Cipher import AES logger = logging.getLogger(__name__) -def log(logger, level, message): +def log(logger: logging.Logger, level: int, message: str) -> None: """Logs message to stderr if logging isn't initialized.""" - if logger.parent.name != 'root': - logger.log(level, message) - else: + if logger.parent and logger.parent.name == "root": print(message, file=sys.stderr) + logger.log(level, message) + -def read_config_file(f, list_values=True): +def read_config_file(f: str | IO[str], list_values: bool = True) -> ConfigObj | None: """Read a config file. *list_values* set to `True` is the default behavior of ConfigObj. @@ -45,26 +34,23 @@ def read_config_file(f, list_values=True): """ - if isinstance(f, basestring): + if isinstance(f, str): f = os.path.expanduser(f) try: - config = ConfigObj(f, interpolation=False, encoding='utf8', - list_values=list_values) + config = ConfigObj(f, interpolation=False, encoding="utf8", list_values=list_values) except ConfigObjError as e: - log(logger, logging.WARNING, "Unable to parse line {0} of config file " - "'{1}'.".format(e.line_number, f)) + log(logger, logging.WARNING, "Unable to parse line {0} of config file '{1}'.".format(e.line_number, f)) log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: - log(logger, logging.WARNING, "You don't have permission to read " - "config file '{0}'.".format(e.filename)) + log(logger, logging.WARNING, "You don't have permission to read config file '{0}'.".format(e.filename)) return None return config -def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: +def get_included_configs(config_file: str | IO[str]) -> list[str | IO[str]]: """Get a list of configuration files that are included into config_path with !includedir directive. @@ -76,29 +62,38 @@ def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: """ if not isinstance(config_file, str) or not os.path.isfile(config_file): return [] - included_configs = [] + included_configs: list[str | IO[str]] = [] try: with open(config_file) as f: - include_directives = filter( - lambda s: s.startswith('!includedir'), - f - ) - dirs = map(lambda s: s.strip().split()[-1], include_directives) - dirs = filter(os.path.isdir, dirs) - for dir in dirs: - for filename in os.listdir(dir): - if filename.endswith('.cnf'): - included_configs.append(os.path.join(dir, filename)) + include_directives = filter(lambda s: s.startswith("!includedir"), f) + dirs_split = (s.strip().split()[-1] for s in include_directives) + dirs = filter(os.path.isdir, dirs_split) + for dir_ in dirs: + for filename in os.listdir(dir_): + if filename.endswith(".cnf"): + included_configs.append(os.path.join(dir_, filename)) except (PermissionError, UnicodeDecodeError): pass return included_configs -def read_config_files(files, list_values=True): +def read_config_files( + files: list[str | IO[str]], + list_values: bool = True, + ignore_package_defaults: bool = False, + ignore_user_options: bool = False, +) -> ConfigObj: """Read and merge a list of config files.""" - config = create_default_config(list_values=list_values) + if ignore_package_defaults: + config = ConfigObj() + else: + config = create_default_config(list_values=list_values) + + if ignore_user_options: + return config + _files = copy(files) while _files: _file = _files.pop(0) @@ -108,38 +103,41 @@ def read_config_files(files, list_values=True): # (otherwise we'll just encounter the same errors again) if config is not None: _files = get_included_configs(_file) + _files - if bool(_config) is True: + if _config is not None: config.merge(_config) config.filename = _config.filename return config -def create_default_config(list_values=True): +def create_default_config(list_values: bool = True) -> ConfigObj: import mycli - default_config_file = resources.open_text(mycli, 'myclirc') + + default_config_file = resources.files(mycli).joinpath("myclirc").open('r') return read_config_file(default_config_file, list_values=list_values) -def write_default_config(destination, overwrite=False): +def write_default_config(destination: str, overwrite: bool = False) -> None: import mycli - default_config = resources.read_text(mycli, 'myclirc') + + with resources.files(mycli).joinpath("myclirc").open('r') as f: + default_config = f.read() destination = os.path.expanduser(destination) if not overwrite and exists(destination): return - with open(destination, 'w') as f: + with open(destination, "w") as f: f.write(default_config) -def get_mylogin_cnf_path(): +def get_mylogin_cnf_path() -> str | None: """Return the path to the login path file or None if it doesn't exist.""" - mylogin_cnf_path = os.getenv('MYSQL_TEST_LOGIN_FILE') + mylogin_cnf_path = os.getenv("MYSQL_TEST_LOGIN_FILE") if mylogin_cnf_path is None: - app_data = os.getenv('APPDATA') - default_dir = os.path.join(app_data, 'MySQL') if app_data else '~' - mylogin_cnf_path = os.path.join(default_dir, '.mylogin.cnf') + app_data = os.getenv("APPDATA") + default_dir = os.path.join(app_data, "MySQL") if app_data else "~" + mylogin_cnf_path = os.path.join(default_dir, ".mylogin.cnf") mylogin_cnf_path = os.path.expanduser(mylogin_cnf_path) @@ -149,7 +147,7 @@ def get_mylogin_cnf_path(): return None -def open_mylogin_cnf(name): +def open_mylogin_cnf(name: str) -> TextIOWrapper | None: """Open a readable version of .mylogin.cnf. Returns the file contents as a TextIOWrapper object. @@ -159,21 +157,21 @@ def open_mylogin_cnf(name): """ try: - with open(name, 'rb') as f: + with open(name, "rb") as f: plaintext = read_and_decrypt_mylogin_cnf(f) except (OSError, IOError, ValueError): - logger.error('Unable to open login path file.') + logger.error("Unable to open login path file.") return None if not isinstance(plaintext, BytesIO): - logger.error('Unable to read login path file.') + logger.error("Unable to read login path file.") return None return TextIOWrapper(plaintext) # TODO reuse code between encryption an decryption -def encrypt_mylogin_cnf(plaintext: IO[str]): +def encrypt_mylogin_cnf(plaintext: IO[str]) -> BytesIO: """Encryption of .mylogin.cnf file, analogous to calling mysql_config_editor. @@ -181,23 +179,21 @@ def encrypt_mylogin_cnf(plaintext: IO[str]): https://github.com/isotopp/mysql-config-coder """ - def realkey(key): + + def realkey(key: bytes) -> bytes: """Create the AES key from the login key.""" rkey = bytearray(16) for i in range(len(key)): rkey[i % 16] ^= key[i] return bytes(rkey) - def encode_line(plaintext, real_key, buf_len): - aes = pyaes.AESModeOfOperationECB(real_key) + def encode_line(plaintext: str, real_key: bytes, buf_len: int) -> bytes: + aes = AES.new(real_key, AES.MODE_ECB) text_len = len(plaintext) pad_len = buf_len - text_len pad_chr = bytes(chr(pad_len), "utf8") - plaintext = plaintext.encode() + pad_chr * pad_len - encrypted_text = b''.join( - [aes.encrypt(plaintext[i: i + 16]) - for i in range(0, len(plaintext), 16)] - ) + plaintext_b = plaintext.encode() + pad_chr * pad_len + encrypted_text = b"".join([aes.encrypt(plaintext_b[i : i + 16]) for i in range(0, len(plaintext_b), 16)]) return encrypted_text LOGIN_KEY_LENGTH = 20 @@ -224,7 +220,7 @@ def encode_line(plaintext, real_key, buf_len): return outfile -def read_and_decrypt_mylogin_cnf(f): +def read_and_decrypt_mylogin_cnf(f: BinaryIO) -> BytesIO | None: """Read and decrypt the contents of .mylogin.cnf. This decryption algorithm mimics the code in MySQL's @@ -248,7 +244,7 @@ def read_and_decrypt_mylogin_cnf(f): buf = f.read(4) if not buf or len(buf) != 4: - logger.error('Login path file is blank or incomplete.') + logger.error("Login path file is blank or incomplete.") return None # Read the login key. @@ -258,87 +254,83 @@ def read_and_decrypt_mylogin_cnf(f): rkey = [0] * 16 for i in range(LOGIN_KEY_LEN): try: - rkey[i % 16] ^= ord(key[i:i+1]) + rkey[i % 16] ^= ord(key[i : i + 1]) except TypeError: # ord() was unable to get the value of the byte. - logger.error('Unable to generate login path AES key.') + logger.error("Unable to generate login path AES key.") return None - rkey = struct.pack('16B', *rkey) + rkey_b = struct.pack("16B", *rkey) # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() - aes = pyaes.AESModeOfOperationECB(rkey) + aes = AES.new(rkey_b, AES.MODE_ECB) while True: # Read the length of the ciphertext. len_buf = f.read(MAX_CIPHER_STORE_LEN) if len(len_buf) < MAX_CIPHER_STORE_LEN: break - cipher_len, = struct.unpack(" bool: """Convert a string value to its corresponding boolean value.""" if isinstance(s, bool): return s - elif not isinstance(s, basestring): - raise TypeError('argument must be a string') + elif not isinstance(s, str): + raise TypeError("argument must be a string") - true_values = ('true', 'on', '1') - false_values = ('false', 'off', '0') + true_values = ("true", "on", "1") + false_values = ("false", "off", "0") if s.lower() in true_values: return True elif s.lower() in false_values: return False else: - raise ValueError('not a recognized boolean value: {0}'.format(s)) + raise ValueError(f'not a recognized boolean value: {s}') -def strip_matching_quotes(s): +def strip_matching_quotes(s: str) -> str: """Remove matching, surrounding quotes from a string. This is the same logic that ConfigObj uses when parsing config values. """ - if (isinstance(s, basestring) and len(s) >= 2 and - s[0] == s[-1] and s[0] in ('"', "'")): + if isinstance(s, str) and len(s) >= 2 and s[0] == s[-1] and s[0] in ('"', "'"): s = s[1:-1] return s -def _remove_pad(line): +def _remove_pad(line: bytes) -> bytes | Literal[False]: """Remove the pad from the *line*.""" try: # Determine pad length. pad_length = ord(line[-1:]) except TypeError: # ord() was unable to get the value of the byte. - logger.warning('Unable to remove pad.') + logger.warning("Unable to remove pad.") return False if pad_length > len(line) or len(set(line[-pad_length:])) != 1: # Pad length should be less than or equal to the length of the # plaintext. The pad should have a single unique byte. - logger.warning('Invalid pad found in login path file.') + logger.warning("Invalid pad found in login path file.") return False return line[:-pad_length] diff --git a/mycli/constants.py b/mycli/constants.py new file mode 100644 index 00000000..f6ef1900 --- /dev/null +++ b/mycli/constants.py @@ -0,0 +1,19 @@ +HOME_URL = 'https://mycli.net' +REPO_URL = 'https://github.com/dbcli/mycli' +DOCS_URL = f'{HOME_URL}/docs' +ISSUES_URL = f'{REPO_URL}/issues' + +DEFAULT_CHARSET = 'utf8mb4' +DEFAULT_DATABASE = 'mysql' +DEFAULT_HOST = 'localhost' +DEFAULT_PORT = 3306 +DEFAULT_USER = 'root' + +TEST_DATABASE = 'mycli_test_db' + +DEFAULT_WIDTH = 80 +DEFAULT_HEIGHT = 25 + +# MySQL error codes not available in pymysql.constants.ER +ER_MUST_CHANGE_PASSWORD_LOGIN = 1862 +ER_MUST_CHANGE_PASSWORD = 1820 diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py index b084849d..950a9af1 100644 --- a/mycli/key_bindings.py +++ b/mycli/key_bindings.py @@ -1,116 +1,309 @@ +from functools import partial import logging +import webbrowser + +import prompt_toolkit +from prompt_toolkit.application.current import get_app from prompt_toolkit.enums import EditingMode -from prompt_toolkit.filters import completion_is_selected, emacs_mode +from prompt_toolkit.filters import ( + Condition, + completion_is_selected, + control_is_searchable, + emacs_mode, +) from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.key_binding.bindings.named_commands import register as ptoolkit_register +from prompt_toolkit.key_binding.key_processor import KeyPressEvent +from prompt_toolkit.selection import SelectionType -from .packages.toolkit.fzf import search_history +from mycli.constants import DOCS_URL +from mycli.packages import key_binding_utils +from mycli.packages.ptoolkit.fzf import search_history +from mycli.packages.ptoolkit.utils import safe_invalidate_display _logger = logging.getLogger(__name__) -def mycli_bindings(mycli): +@Condition +def ctrl_d_condition() -> bool: + """Ctrl-D exit binding is only active when the buffer is empty.""" + app = get_app() + return not app.current_buffer.text + + +@Condition +def in_completion() -> bool: + app = get_app() + return bool(app.current_buffer.complete_state) + + +def print_f1_help(): + app = get_app() + app.print_text('\n') + app.print_text([ + ('', 'Inline help — type "'), + ('bold', 'help'), + ('', '" or "'), + ('bold', r'\?'), + ('', '"\n'), + ]) + app.print_text([ + ('', 'Docs index — '), + ('bold', DOCS_URL), + ('', '\n'), + ]) + app.print_text('\n') + + +@ptoolkit_register("edit-and-execute-command") +def edit_and_execute(event: KeyPressEvent) -> None: + """Different from the prompt-toolkit default, we want to have a choice not + to execute a query after editing, hence validate_and_handle=False.""" + buff = event.current_buffer + buff.open_in_editor(validate_and_handle=False) + + +def mycli_bindings(mycli) -> KeyBindings: """Custom key bindings for mycli.""" kb = KeyBindings() - @kb.add('f2') - def _(event): + @kb.add('f1') + def _(event: KeyPressEvent) -> None: + """Open browser to documentation index.""" + _logger.debug('Detected F1 key.') + webbrowser.open_new_tab(DOCS_URL) + prompt_toolkit.application.run_in_terminal(print_f1_help) + safe_invalidate_display(event.app) + + @kb.add('escape', '[', 'P') + def _(event: KeyPressEvent) -> None: + """Open browser to documentation index.""" + _logger.debug("Detected alternate F1 key sequence.") + webbrowser.open_new_tab(DOCS_URL) + prompt_toolkit.application.run_in_terminal(print_f1_help) + safe_invalidate_display(event.app) + + @kb.add("f2") + def _(_event: KeyPressEvent) -> None: """Enable/Disable SmartCompletion Mode.""" - _logger.debug('Detected F2 key.') + _logger.debug("Detected F2 key.") mycli.completer.smart_completion = not mycli.completer.smart_completion - @kb.add('f3') - def _(event): + @kb.add('escape', '[', 'Q') + def _(_event: KeyPressEvent) -> None: + """Enable/Disable SmartCompletion Mode.""" + _logger.debug("Detected alternate F2 key sequence.") + mycli.completer.smart_completion = not mycli.completer.smart_completion + + @kb.add("f3") + def _(_event: KeyPressEvent) -> None: + """Enable/Disable Multiline Mode.""" + _logger.debug("Detected F3 key.") + mycli.multi_line = not mycli.multi_line + + @kb.add('escape', '[', 'R') + def _(_event: KeyPressEvent) -> None: """Enable/Disable Multiline Mode.""" - _logger.debug('Detected F3 key.') + _logger.debug('Detected alternate F3 key sequence.') mycli.multi_line = not mycli.multi_line - @kb.add('f4') - def _(event): + @kb.add("f4") + def _(event: KeyPressEvent) -> None: """Toggle between Vi and Emacs mode.""" - _logger.debug('Detected F4 key.') + _logger.debug("Detected F4 key.") if mycli.key_bindings == "vi": event.app.editing_mode = EditingMode.EMACS mycli.key_bindings = "emacs" + event.app.ttimeoutlen = mycli.emacs_ttimeoutlen else: event.app.editing_mode = EditingMode.VI mycli.key_bindings = "vi" + event.app.ttimeoutlen = mycli.vi_ttimeoutlen + + @kb.add('escape', '[', 'S') + def _(event: KeyPressEvent) -> None: + """Toggle between Vi and Emacs mode.""" + _logger.debug('Detected alternate F4 key sequence.') + if mycli.key_bindings == 'vi': + event.app.editing_mode = EditingMode.EMACS + mycli.key_bindings = 'emacs' + event.app.ttimeoutlen = mycli.emacs_ttimeoutlen + else: + event.app.editing_mode = EditingMode.VI + mycli.key_bindings = 'vi' + event.app.ttimeoutlen = mycli.vi_ttimeoutlen - @kb.add('tab') - def _(event): - """Force autocompletion at cursor.""" - _logger.debug('Detected key.') + @kb.add("tab") + def _(event: KeyPressEvent) -> None: + """Complete action at cursor.""" + _logger.debug("Detected key.") b = event.app.current_buffer + + behaviors = mycli.config['keys'].as_list('tab') + + if 'toolkit_default' in behaviors: + if b.complete_state: + b.complete_next() + else: + b.start_completion(select_first=True) + if b.complete_state: - b.complete_next() - else: + if 'advance' in behaviors: + b.complete_next() + elif 'cancel' in behaviors: + b.cancel_completion() + return + + if 'advancing_summon' in behaviors: b.start_completion(select_first=True) + elif 'prefixing_summon' in behaviors: + b.start_completion(insert_common_part=True) + elif 'summon' in behaviors: + b.start_completion(select_first=False) + + @kb.add("escape", eager=True, filter=in_completion) + def _(event: KeyPressEvent) -> None: + """Cancel completion menu. + + There will be a lag when canceling Escape due to the processing of + Alt- keystrokes as Escape- sequences. - @kb.add('c-space') - def _(event): + There will be no lag when using control-g to cancel.""" + event.app.current_buffer.cancel_completion() + + @kb.add("c-space") + def _(event: KeyPressEvent) -> None: """ - Initialize autocompletion at cursor. + Complete action at cursor. - If the autocompletion menu is not showing, display it with the + By default, if the autocompletion menu is not showing, display it with the appropriate completions for the context. If the menu is showing, select the next completion. """ - _logger.debug('Detected key.') + _logger.debug("Detected key.") b = event.app.current_buffer + + behaviors = mycli.config['keys'].as_list('control_space') + + if 'toolkit_default' in behaviors: + if b.text: + b.start_selection(selection_type=SelectionType.CHARACTERS) + return + if b.complete_state: - b.complete_next() - else: + if 'advance' in behaviors: + b.complete_next() + elif 'cancel' in behaviors: + b.cancel_completion() + return + + if 'advancing_summon' in behaviors: + b.start_completion(select_first=True) + elif 'prefixing_summon' in behaviors: + b.start_completion(insert_common_part=True) + elif 'summon' in behaviors: b.start_completion(select_first=False) - @kb.add('c-x', 'p', filter=emacs_mode) - def _(event): + @kb.add("c-x", "p", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: """ Prettify and indent current statement, usually into multiple lines. Only accepts buffers containing single SQL statements. """ - _logger.debug('Detected /> key.') + _logger.debug("Detected /> key.") b = event.app.current_buffer - cursorpos_relative = b.cursor_position / max(1, len(b.text)) - pretty_text = mycli.handle_prettify_binding(b.text) - if len(pretty_text) > 0: - b.text = pretty_text - cursorpos_abs = int(round(cursorpos_relative * len(b.text))) - while 0 < cursorpos_abs < len(b.text) \ - and b.text[cursorpos_abs] in (' ', '\n'): - cursorpos_abs -= 1 - b.cursor_position = min(cursorpos_abs, len(b.text)) - - @kb.add('c-x', 'u', filter=emacs_mode) - def _(event): + if b.text: + b.transform_region(0, len(b.text), partial(key_binding_utils.handle_prettify_binding, mycli)) + + @kb.add("c-x", "u", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: """ Unprettify and dedent current statement, usually into one line. Only accepts buffers containing single SQL statements. """ - _logger.debug('Detected /< key.') + _logger.debug("Detected /< key.") b = event.app.current_buffer - cursorpos_relative = b.cursor_position / max(1, len(b.text)) - unpretty_text = mycli.handle_unprettify_binding(b.text) - if len(unpretty_text) > 0: - b.text = unpretty_text - cursorpos_abs = int(round(cursorpos_relative * len(b.text))) - while 0 < cursorpos_abs < len(b.text) \ - and b.text[cursorpos_abs] in (' ', '\n'): - cursorpos_abs -= 1 - b.cursor_position = min(cursorpos_abs, len(b.text)) - - @kb.add('c-r', filter=emacs_mode) - def _(event): - """Search history using fzf or default reverse incremental search.""" - _logger.debug('Detected key.') - search_history(event) - - @kb.add('enter', filter=completion_is_selected) - def _(event): + if b.text: + b.transform_region(0, len(b.text), partial(key_binding_utils.handle_unprettify_binding, mycli)) + + @kb.add("c-o", "d", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: + """ + Insert the current date. + """ + _logger.debug("Detected key.") + + event.app.current_buffer.insert_text(key_binding_utils.server_date(mycli.sqlexecute)) + + @kb.add("c-o", "c-d", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: + """ + Insert the quoted current date. + """ + _logger.debug("Detected key.") + + event.app.current_buffer.insert_text(key_binding_utils.server_date(mycli.sqlexecute, quoted=True)) + + @kb.add("c-o", "t", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: + """ + Insert the current datetime. + """ + _logger.debug("Detected key.") + + event.app.current_buffer.insert_text(key_binding_utils.server_datetime(mycli.sqlexecute)) + + @kb.add("c-o", "c-t", filter=emacs_mode) + def _(event: KeyPressEvent) -> None: + """ + Insert the quoted current datetime. + """ + _logger.debug("Detected key.") + + event.app.current_buffer.insert_text(key_binding_utils.server_datetime(mycli.sqlexecute, quoted=True)) + + @kb.add("c-r", filter=control_is_searchable) + def _(event: KeyPressEvent) -> None: + """Search history using fzf or reverse incremental search.""" + _logger.debug("Detected key.") + mode = mycli.config.get('keys', {}).get('control_r', 'auto') + if mode == 'reverse_isearch': + search_history(event, incremental=True) + else: + search_history( + event, + highlight_preview=mycli.highlight_preview, + highlight_style=mycli.syntax_style, + ) + + @kb.add("escape", "r", filter=control_is_searchable & emacs_mode) + def _(event: KeyPressEvent) -> None: + """Search history using fzf when available.""" + _logger.debug("Detected key.") + search_history( + event, + highlight_preview=mycli.highlight_preview, + highlight_style=mycli.syntax_style, + ) + + @kb.add('c-d', filter=ctrl_d_condition) + def _(event: KeyPressEvent) -> None: + """Exit mycli or ignore keypress.""" + _logger.debug('Detected key on empty line.') + mode = mycli.config.get('keys', {}).get('control_d', 'exit') + if mode == 'exit': + event.app.exit(exception=EOFError, style='class:exiting') + else: + event.app.output.bell() + + @kb.add("enter", filter=completion_is_selected) + def _(event: KeyPressEvent) -> None: """Makes the enter key work as the tab key only when showing the menu. In other words, don't execute query when enter is pressed in @@ -118,20 +311,20 @@ def _(event): (accept current selection). """ - _logger.debug('Detected enter key.') + _logger.debug("Detected enter key.") event.current_buffer.complete_state = None b = event.app.current_buffer b.complete_state = None - @kb.add('escape', 'enter') - def _(event): + @kb.add("escape", "enter") + def _(event: KeyPressEvent) -> None: """Introduces a line break in multi-line mode, or dispatches the command in single-line mode.""" - _logger.debug('Detected alt-enter key.') + _logger.debug("Detected alt-enter key.") if mycli.multi_line: event.app.current_buffer.validate_and_handle() else: - event.app.current_buffer.insert_text('\n') + event.app.current_buffer.insert_text("\n") return kb diff --git a/mycli/lexer.py b/mycli/lexer.py index 4b14d72d..3350d11f 100644 --- a/mycli/lexer.py +++ b/mycli/lexer.py @@ -7,6 +7,5 @@ class MyCliLexer(MySqlLexer): """Extends MySQL lexer to add keywords.""" tokens = { - 'root': [(r'\brepair\b', Keyword), - (r'\boffset\b', Keyword), inherit], + "root": [(r"\brepair\b", Keyword), (r"\boffset\b", Keyword), inherit], } diff --git a/mycli/magic.py b/mycli/magic.py deleted file mode 100644 index e1611bcc..00000000 --- a/mycli/magic.py +++ /dev/null @@ -1,62 +0,0 @@ -from .main import MyCli -import sql.parse -import sql.connection -import logging - -_logger = logging.getLogger(__name__) - -def load_ipython_extension(ipython): - - # This is called via the ipython command '%load_ext mycli.magic'. - - # First, load the sql magic if it isn't already loaded. - if not ipython.find_line_magic('sql'): - ipython.run_line_magic('load_ext', 'sql') - - # Register our own magic. - ipython.register_magic_function(mycli_line_magic, 'line', 'mycli') - -def mycli_line_magic(line): - _logger.debug('mycli magic called: %r', line) - parsed = sql.parse.parse(line, {}) - # "get" was renamed to "set" in ipython-sql: - # https://github.com/catherinedevlin/ipython-sql/commit/f4283c65aaf68f961e84019e8b939e4a3c501d43 - if hasattr(sql.connection.Connection, "get"): - conn = sql.connection.Connection.get(parsed["connection"]) - else: - try: - conn = sql.connection.Connection.set(parsed["connection"]) - # a new positional argument was added to Connection.set in version 0.4.0 of ipython-sql - except TypeError: - conn = sql.connection.Connection.set(parsed["connection"], False) - try: - # A corresponding mycli object already exists - mycli = conn._mycli - _logger.debug('Reusing existing mycli') - except AttributeError: - mycli = MyCli() - u = conn.session.engine.url - _logger.debug('New mycli: %r', str(u)) - - mycli.connect(host=u.host, port=u.port, passwd=u.password, database=u.database, user=u.username, init_command=None) - conn._mycli = mycli - - # For convenience, print the connection alias - print('Connected: {}'.format(conn.name)) - - try: - mycli.run_cli() - except SystemExit: - pass - - if not mycli.query_history: - return - - q = mycli.query_history[-1] - if q.mutating: - _logger.debug('Mutating query detected -- ignoring') - return - - if q.successful: - ipython = get_ipython() - return ipython.run_cell_magic('sql', line, q.query) diff --git a/mycli/main.py b/mycli/main.py index 4c194ced..bbe8f5d4 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,131 +1,143 @@ -from collections import defaultdict -from io import open +from __future__ import annotations + +from io import TextIOWrapper +import logging import os +import re import sys -import shutil -import traceback -import logging import threading -import re -import stat -import fileinput -from collections import namedtuple +import traceback +from typing import IO, Any, Generator, Iterable, Literal + try: from pwd import getpwuid except ImportError: pass -from time import time -from datetime import datetime -from random import choice +from textwrap import dedent +from urllib.parse import parse_qs, unquote, urlparse -from pymysql import OperationalError from cli_helpers.tabular_output import TabularOutputFormatter -from cli_helpers.tabular_output import preprocessors -from cli_helpers.utils import strip_ansi +from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as _DEFAULT_MISSING_VALUE import click +import clickdc +import keyring +from prompt_toolkit.formatted_text import ( + to_formatted_text, +) +from prompt_toolkit.shortcuts import PromptSession +import pymysql +from pymysql.constants.CR import CR_SERVER_LOST +from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR +from pymysql.cursors import Cursor import sqlparse -import sqlglot -from mycli.packages.parseutils import is_dropping_database, is_destructive -from prompt_toolkit.completion import DynamicCompleter -from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode -from prompt_toolkit.key_binding.bindings.named_commands import register as prompt_register -from prompt_toolkit.shortcuts import PromptSession, CompleteStyle -from prompt_toolkit.document import Document -from prompt_toolkit.filters import HasFocus, IsDone -from prompt_toolkit.formatted_text import ANSI -from prompt_toolkit.layout.processors import (HighlightMatchingBracketProcessor, - ConditionalProcessor) -from prompt_toolkit.lexers import PygmentsLexer -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory - -from .packages.special.main import NO_QUERY -from .packages.prompt_utils import confirm, confirm_destructive_query -from .packages.tabular_output import sql_format -from .packages import special -from .packages.special.favoritequeries import FavoriteQueries -from .packages.toolkit.history import FileHistoryWithTimestamp -from .sqlcompleter import SQLCompleter -from .clitoolbar import create_toolbar_tokens_func -from .clistyle import style_factory, style_factory_output -from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED -from .clibuffer import cli_is_multiline -from .completion_refresher import CompletionRefresher -from .config import (write_default_config, get_mylogin_cnf_path, - open_mylogin_cnf, read_config_files, str_to_bool, - strip_matching_quotes) -from .key_bindings import mycli_bindings -from .lexer import MyCliLexer -from . import __version__ -from .compat import WIN -from .packages.filepaths import dir_path_exists, guess_socket_location - -import itertools - -click.disable_unicode_literals_warning = True -try: - from urlparse import urlparse - from urlparse import unquote -except ImportError: - from urllib.parse import urlparse - from urllib.parse import unquote - -try: - import importlib.resources as resources -except ImportError: - # Python < 3.7 - import importlib_resources as resources - -try: - import paramiko -except ImportError: - from mycli.packages.paramiko_stub import paramiko - -# Query tuples are used for maintaining history -Query = namedtuple('Query', ['query', 'successful', 'mutating']) - -SUPPORT_INFO = ( - 'Home: http://mycli.net\n' - 'Bug tracker: https://github.com/dbcli/mycli/issues' +import mycli as mycli_package +from mycli.app_state import ( + AppStateMixin, + configure_prompt_state, + destructive_keywords_from_config, + ensure_my_cnf_sections, + llm_prompt_truncation, + normalize_ssl_mode, ) - - -class MyCli(object): - - default_prompt = '\\t \\u@\\h:\\d> ' - default_prompt_splitln = '\\u@\\h\\n(\\t):\\d>' +from mycli.cli_args import ( + DEFAULT_PROMPT, + EMPTY_PASSWORD_FLAG_SENTINEL, + CliArgs, + preprocess_cli_args, +) +from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit +from mycli.compat import WIN +from mycli.completion_refresher import CompletionRefresher +from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, write_default_config +from mycli.constants import ( + DEFAULT_CHARSET, + DEFAULT_HOST, + DEFAULT_PORT, + ER_MUST_CHANGE_PASSWORD_LOGIN, + ISSUES_URL, + REPO_URL, +) +from mycli.main_modes import repl as repl_package +from mycli.main_modes.batch import ( + main_batch_from_stdin, + main_batch_with_progress_bar, + main_batch_without_progress_bar, +) +from mycli.main_modes.checkup import main_checkup +from mycli.main_modes.execute import main_execute_from_cli +from mycli.main_modes.list_dsn import main_list_dsn +from mycli.main_modes.list_ssh_config import main_list_ssh_config +from mycli.main_modes.repl import main_repl, set_all_external_titles +from mycli.output import OutputMixin +from mycli.packages import special +from mycli.packages.cli_utils import filtered_sys_argv, is_valid_connection_scheme +from mycli.packages.filepaths import dir_path_exists, guess_socket_location +from mycli.packages.interactive_utils import confirm_destructive_query +from mycli.packages.special.favoritequeries import FavoriteQueries +from mycli.packages.special.main import ArgType, SpecialCommandAlias +from mycli.packages.sqlresult import SQLResult +from mycli.packages.ssh_utils import read_ssh_config +from mycli.packages.tabular_output import sql_format +from mycli.schema_prefetcher import SchemaPrefetcher +from mycli.sqlcompleter import SQLCompleter +from mycli.sqlexecute import SQLExecute +from mycli.types import Query + +sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] +sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] + +DEFAULT_MISSING_VALUE = _DEFAULT_MISSING_VALUE + + +class MyCli(AppStateMixin, OutputMixin): + default_prompt = DEFAULT_PROMPT + default_prompt_splitln = "\\u@\\h\\n(\\t):\\d>" max_len_prompt = 45 defaults_suffix = None + prompt_lines: int # In order of being loaded. Files lower in list override earlier ones. - cnf_files = [ - '/etc/my.cnf', - '/etc/mysql/my.cnf', - '/usr/local/etc/my.cnf', - os.path.expanduser('~/.my.cnf'), + cnf_files: list[str | IO[str]] = [ + "/etc/my.cnf", + "/etc/mysql/my.cnf", + "/usr/local/etc/my.cnf", + os.path.expanduser("~/.my.cnf"), ] # check XDG_CONFIG_HOME exists and not an empty string - if os.environ.get("XDG_CONFIG_HOME"): - xdg_config_home = os.environ.get("XDG_CONFIG_HOME") - else: - xdg_config_home = "~/.config" - system_config_files = [ - '/etc/myclirc', - os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") + xdg_config_home = os.environ.get("XDG_CONFIG_HOME", "~/.config") + system_config_files: list[str | IO[str]] = [ + "/etc/myclirc", + os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc"), ] pwd_config_file = os.path.join(os.getcwd(), ".myclirc") - def __init__(self, sqlexecute=None, prompt=None, - logfile=None, defaults_suffix=None, defaults_file=None, - login_path=None, auto_vertical_output=False, warn=None, - myclirc="~/.myclirc"): + def __init__( + self, + sqlexecute: SQLExecute | None = None, + prompt: str | None = None, + toolbar_format: str | None = None, + logfile: TextIOWrapper | Literal[False] | None = None, + defaults_suffix: str | None = None, + defaults_file: str | None = None, + login_path: str | None = None, + auto_vertical_output: bool = False, + warn: bool | None = None, + myclirc: str = "~/.myclirc", + show_warnings: bool | None = None, + cli_verbosity: int = 0, + ) -> None: self.sqlexecute = sqlexecute self.logfile = logfile self.defaults_suffix = defaults_suffix self.login_path = login_path - self.toolbar_error_message = None + self.toolbar_error_message: str | None = None + self.prompt_session: PromptSession | None = None + self._keepalive_counter = 0 + self.keepalive_ticks: int | None = 0 + self.sandbox_mode: bool = False # self.cnf_files is a class variable that stores the list of mysql # config files to read in at launch. @@ -135,71 +147,102 @@ def __init__(self, sqlexecute=None, prompt=None, self.cnf_files = [defaults_file] # Load config. - config_files = (self.system_config_files + - [myclirc] + [self.pwd_config_file]) + config_files: list[str | IO[str]] = self.system_config_files + [myclirc] + [self.pwd_config_file] c = self.config = read_config_files(config_files) - self.multi_line = c['main'].as_bool('multi_line') - self.key_bindings = c['main']['key_bindings'] - special.set_timing_enabled(c['main'].as_bool('timing')) - self.beep_after_seconds = float(c['main']['beep_after_seconds'] or 0) + # this parallel config exists to + # * compare with my.cnf + # * support the --checkup feature + # todo: after removing my.cnf, create the parallel configs only when --checkup is set + self.config_without_package_defaults = read_config_files(config_files, ignore_package_defaults=True) + # this parallel config exists to compare with my.cnf support the --checkup feature + self.config_without_user_options = read_config_files(config_files, ignore_user_options=True) + self.multi_line = c["main"].as_bool("multi_line") + self.key_bindings = c["main"]["key_bindings"] + self.emacs_ttimeoutlen = c['keys'].as_float('emacs_ttimeoutlen') + self.vi_ttimeoutlen = c['keys'].as_float('vi_ttimeoutlen') + special.set_timing_enabled(c["main"].as_bool("timing")) + special.set_show_favorite_query(c["main"].as_bool("show_favorite_query")) + if show_warnings is not None: + special.set_show_warnings_enabled(show_warnings) + else: + special.set_show_warnings_enabled(c['main'].as_bool('show_warnings')) + self.beep_after_seconds = float(c["main"]["beep_after_seconds"] or 0) + self.default_keepalive_ticks = c['connection'].as_int('default_keepalive_ticks') FavoriteQueries.instance = FavoriteQueries.from_config(self.config) - self.dsn_alias = None - self.formatter = TabularOutputFormatter( - format_name=c['main']['table_format']) - sql_format.register_new_formatter(self.formatter) - self.formatter.mycli = self - self.syntax_style = c['main']['syntax_style'] - self.less_chatty = c['main'].as_bool('less_chatty') - self.cli_style = c['colors'] - self.output_style = style_factory_output( - self.syntax_style, - self.cli_style - ) - self.wider_completion_menu = c['main'].as_bool('wider_completion_menu') - c_dest_warning = c['main'].as_bool('destructive_warning') + self.dsn_alias: str | None = None + self.main_formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + self.redirect_formatter = TabularOutputFormatter(format_name=c["main"].get("redirect_format", "csv")) + sql_format.register_new_formatter(self.main_formatter) + sql_format.register_new_formatter(self.redirect_formatter) + self.main_formatter.mycli = self + self.redirect_formatter.mycli = self + self.syntax_style = c["main"]["syntax_style"] + self.verbosity = -1 if c["main"].as_bool("less_chatty") else 0 + if cli_verbosity: + self.verbosity = cli_verbosity + self.cli_style = c["colors"] + self.ptoolkit_style = style_factory_ptoolkit(self.syntax_style, self.cli_style) + self.helpers_style = style_factory_helpers(self.syntax_style, self.cli_style) + self.helpers_warnings_style = style_factory_helpers(self.syntax_style, self.cli_style, warnings=True) + self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") + c_dest_warning = c["main"].as_bool("destructive_warning") self.destructive_warning = c_dest_warning if warn is None else warn - self.login_path_as_host = c['main'].as_bool('login_path_as_host') + self.login_path_as_host = c["main"].as_bool("login_path_as_host") + self.post_redirect_command = c['main'].get('post_redirect_command') + self.null_string = c['main'].get('null_string') + self.numeric_alignment = c['main'].get('numeric_alignment', 'right') + self.binary_display = c['main'].get('binary_display') + self.llm_prompt_field_truncate, self.llm_prompt_section_truncate = llm_prompt_truncation(c) + + self.ssl_mode, ssl_mode_error = normalize_ssl_mode(c) + if ssl_mode_error: + self.echo(ssl_mode_error, err=True, fg="red") # read from cli argument or user config file - self.auto_vertical_output = auto_vertical_output or \ - c['main'].as_bool('auto_vertical_output') + self.auto_vertical_output = auto_vertical_output or c["main"].as_bool("auto_vertical_output") # Write user config if system config wasn't the last config loaded. if c.filename not in self.system_config_files and not os.path.exists(myclirc): write_default_config(myclirc) # audit log - if self.logfile is None and 'audit_log' in c['main']: + if self.logfile is None and "audit_log" in c["main"]: try: - self.logfile = open(os.path.expanduser(c['main']['audit_log']), 'a') - except (IOError, OSError) as e: - self.echo('Error: Unable to open the audit log file. Your queries will not be logged.', - err=True, fg='red') + self.logfile = open(os.path.expanduser(c["main"]["audit_log"]), "a") + except (IOError, OSError): + self.echo("Error: Unable to open the audit log file. Your queries will not be logged.", err=True, fg="red") self.logfile = False self.completion_refresher = CompletionRefresher() + self.prefetch_schemas_mode = c["main"].get("prefetch_schemas_mode", "always") or "always" + raw_prefetch_list = c["main"].as_list("prefetch_schemas_list") if "prefetch_schemas_list" in c["main"] else [] + self.prefetch_schemas_list = [s.strip() for s in raw_prefetch_list if s and s.strip()] + self.schema_prefetcher = SchemaPrefetcher(self) self.logger = logging.getLogger(__name__) self.initialize_logging() - prompt_cnf = self.read_my_cnf_files(self.cnf_files, ['prompt'])['prompt'] - self.prompt_format = prompt or prompt_cnf or c['main']['prompt'] or \ - self.default_prompt - self.multiline_continuation_char = c['main']['prompt_continuation'] - keyword_casing = c['main'].get('keyword_casing', 'auto') + keyword_casing = c["main"].get("keyword_casing", "auto") + + self.highlight_preview = c['search'].as_bool('highlight_preview') - self.query_history = [] + self.query_history: list[Query] = [] # Initialize completer. - self.smart_completion = c['main'].as_bool('smart_completion') + self.smart_completion = c["main"].as_bool("smart_completion") self.completer = SQLCompleter( - self.smart_completion, - supported_formats=self.formatter.supported_formats, - keyword_casing=keyword_casing) + self.smart_completion, supported_formats=self.main_formatter.supported_formats, keyword_casing=keyword_casing + ) self._completer_lock = threading.Lock() + self.min_completion_trigger = c["main"].as_int("min_completion_trigger") + # a hack, pending a better way to handle settings and state + repl_package.MIN_COMPLETION_TRIGGER = self.min_completion_trigger + self.last_prompt_message = to_formatted_text('') + self.last_custom_toolbar_message = to_formatted_text('') + # Register custom special commands. self.register_special_commands() @@ -212,989 +255,638 @@ def __init__(self, sqlexecute=None, prompt=None, self.cnf_files.append(mylogin_cnf) elif mylogin_cnf_path and not mylogin_cnf: # There was an error reading the login path file. - print('Error: Unable to read login path file.') - - self.prompt_app = None - - def register_special_commands(self): - special.register_special_command(self.change_db, 'use', - '\\u', 'Change to a new database.', aliases=('\\u',)) - special.register_special_command(self.change_db, 'connect', - '\\r', 'Reconnect to the database. Optional database argument.', - aliases=('\\r', ), case_sensitive=True) - special.register_special_command(self.refresh_completions, 'rehash', - '\\#', 'Refresh auto-completions.', arg_type=NO_QUERY, aliases=('\\#',)) + print("Error: Unable to read login path file.") + + self.my_cnf = read_config_files(self.cnf_files, list_values=False) + ensure_my_cnf_sections(self.my_cnf) + prompt_cnf = self.read_my_cnf(self.my_cnf, ["prompt"])["prompt"] + configure_prompt_state(self, c, prompt, prompt_cnf, toolbar_format) + self.prompt_session = None + self.destructive_keywords = destructive_keywords_from_config(c) + special.set_destructive_keywords(self.destructive_keywords) + + def close(self) -> None: + if hasattr(self, 'schema_prefetcher'): + self.schema_prefetcher.stop() + if self.sqlexecute is not None: + self.sqlexecute.close() + + def register_special_commands(self) -> None: special.register_special_command( - self.change_table_format, 'tableformat', '\\T', - 'Change the table format used to output results.', - aliases=('\\T',), case_sensitive=True) - special.register_special_command(self.execute_from_file, 'source', '\\. filename', - 'Execute commands from file.', aliases=('\\.',)) - special.register_special_command(self.change_prompt_format, 'prompt', - '\\R', 'Change prompt format.', aliases=('\\R',), case_sensitive=True) - - def change_table_format(self, arg, **_): + self.change_db, + "use", + "use ", + "Change to a new database.", + aliases=[SpecialCommandAlias("\\u", case_sensitive=False)], + ) + special.register_special_command( + self.manual_reconnect, + "connect", + "connect [database]", + "Reconnect to the server, optionally switching databases.", + case_sensitive=True, + aliases=[SpecialCommandAlias("\\r", case_sensitive=True)], + ) + special.register_special_command( + self.refresh_completions, + "rehash", + "rehash", + "Refresh auto-completions.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\#", case_sensitive=False)], + ) + special.register_special_command( + self.change_table_format, + "tableformat", + "tableformat ", + "Change the table format used to output interactive results.", + case_sensitive=True, + aliases=[SpecialCommandAlias("\\T", case_sensitive=True)], + ) + special.register_special_command( + self.change_redirect_format, + "redirectformat", + "redirectformat ", + "Change the table format used to output redirected results.", + case_sensitive=True, + aliases=[SpecialCommandAlias("\\Tr", case_sensitive=True)], + ) + special.register_special_command( + self.execute_from_file, + "source", + "source ", + "Execute queries from a file.", + aliases=[SpecialCommandAlias("\\.", case_sensitive=False)], + ) + special.register_special_command( + self.change_prompt_format, + "prompt", + "prompt ", + "Change prompt format.", + case_sensitive=True, + aliases=[SpecialCommandAlias("\\R", case_sensitive=True)], + ) + + def manual_reconnect(self, arg: str = "", **_) -> Generator[SQLResult, None, None]: + """ + Interactive method to use for the \r command, so that the utility method + may be cleanly used elsewhere. + """ + if not self.reconnect(database=arg): + yield SQLResult(status="Not connected") + elif not arg or arg == '``': + yield SQLResult() + else: + yield self.change_db(arg).send(None) + + def change_table_format(self, arg: str, **_) -> Generator[SQLResult, None, None]: + try: + self.main_formatter.format_name = arg + yield SQLResult(status=f"Changed table format to {arg}") + except ValueError: + msg = f"Table format {arg} not recognized. Allowed formats:" + for table_type in self.main_formatter.supported_formats: + msg += f"\n\t{table_type}" + yield SQLResult(status=msg) + + def change_redirect_format(self, arg: str, **_) -> Generator[SQLResult, None, None]: try: - self.formatter.format_name = arg - yield (None, None, None, - 'Changed table format to {}'.format(arg)) + self.redirect_formatter.format_name = arg + yield SQLResult(status=f"Changed redirect format to {arg}") except ValueError: - msg = 'Table format {} not recognized. Allowed formats:'.format( - arg) - for table_type in self.formatter.supported_formats: - msg += "\n\t{}".format(table_type) - yield (None, None, None, msg) + msg = f"Redirect format {arg} not recognized. Allowed formats:" + for table_type in self.redirect_formatter.supported_formats: + msg += f"\n\t{table_type}" + yield SQLResult(status=msg) + + def change_db(self, arg: str, **_) -> Generator[SQLResult, None, None]: + if arg.startswith("`") and arg.endswith("`"): + arg = re.sub(r"^`(.*)`$", r"\1", arg) + arg = re.sub(r"``", r"`", arg) - def change_db(self, arg, **_): if not arg: - click.secho( - "No database selected", - err=True, fg="red" - ) + click.secho("No database selected", err=True, fg="red") return - if arg.startswith('`') and arg.endswith('`'): - arg = re.sub(r'^`(.*)`$', r'\1', arg) - arg = re.sub(r'``', r'`', arg) - self.sqlexecute.change_db(arg) + assert isinstance(self.sqlexecute, SQLExecute) - yield (None, None, None, 'You are now connected to database "%s" as ' - 'user "%s"' % (self.sqlexecute.dbname, self.sqlexecute.user)) + if self.sqlexecute.dbname == arg: + msg = f'You are already connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"' + else: + self.sqlexecute.change_db(arg) + msg = f'You are now connected to database "{self.sqlexecute.dbname}" as user "{self.sqlexecute.user}"' + + # todo: this jump back to repl.py is a sign that separation is incomplete. + # also: it should not be needed. Don't titles update on every new prompt? + set_all_external_titles(self) - def execute_from_file(self, arg, **_): + yield SQLResult(status=msg) + + def execute_from_file(self, arg: str, **_) -> Iterable[SQLResult]: if not arg: - message = 'Missing required argument, filename.' - return [(None, None, None, message)] + message = "Missing required argument: filename." + return [SQLResult(status=message)] try: with open(os.path.expanduser(arg)) as f: query = f.read() except IOError as e: - return [(None, None, None, str(e))] + return [SQLResult(status=str(e))] - if (self.destructive_warning and - confirm_destructive_query(query) is False): - message = 'Wise choice. Command execution stopped.' - return [(None, None, None, message)] + if self.destructive_warning and confirm_destructive_query(self.destructive_keywords, query) is False: + message = "Wise choice. Command execution stopped." + return [SQLResult(status=message)] + assert isinstance(self.sqlexecute, SQLExecute) return self.sqlexecute.run(query) - def change_prompt_format(self, arg, **_): + def change_prompt_format(self, arg: str, **_) -> list[SQLResult]: """ Change the prompt format. """ if not arg: - message = 'Missing required argument, format.' - return [(None, None, None, message)] - - self.prompt_format = self.get_prompt(arg) - return [(None, None, None, "Changed prompt format to %s" % arg)] - - def initialize_logging(self): - - log_file = os.path.expanduser(self.config['main']['log_file']) - log_level = self.config['main']['log_level'] - - level_map = {'CRITICAL': logging.CRITICAL, - 'ERROR': logging.ERROR, - 'WARNING': logging.WARNING, - 'INFO': logging.INFO, - 'DEBUG': logging.DEBUG - } + message = "Missing required argument, format." + return [SQLResult(status=message)] + + self.prompt_format = arg + return [SQLResult(status=f"Changed prompt format to {arg}")] + + def initialize_logging(self) -> None: + log_file = os.path.expanduser(self.config["main"]["log_file"]) + log_level = self.config["main"]["log_level"] + + level_map = { + "CRITICAL": logging.CRITICAL, + "ERROR": logging.ERROR, + "WARNING": logging.WARNING, + "INFO": logging.INFO, + "DEBUG": logging.DEBUG, + } # Disable logging if value is NONE by switching to a no-op handler # Set log level to a high value so it doesn't even waste cycles getting called. if log_level.upper() == "NONE": - handler = logging.NullHandler() + handler: logging.Handler = logging.NullHandler() log_level = "CRITICAL" elif dir_path_exists(log_file): handler = logging.FileHandler(log_file) else: - self.echo( - 'Error: Unable to open the log file "{}".'.format(log_file), - err=True, fg='red') + self.echo(f'Error: Unable to open the log file "{log_file}".', err=True, fg="red") return - formatter = logging.Formatter( - '%(asctime)s (%(process)d/%(threadName)s) ' - '%(name)s %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s (%(process)d/%(threadName)s) %(name)s %(levelname)s - %(message)s") handler.setFormatter(formatter) - root_logger = logging.getLogger('mycli') + root_logger = logging.getLogger("mycli") root_logger.addHandler(handler) root_logger.setLevel(level_map[log_level.upper()]) logging.captureWarnings(True) - root_logger.debug('Initializing mycli logging.') - root_logger.debug('Log file %r.', log_file) - - - def read_my_cnf_files(self, files, keys): - """ - Reads a list of config files and merges them. The last one will win. - :param files: list of files to read - :param keys: list of keys to retrieve - :returns: tuple, with None for missing keys. - """ - cnf = read_config_files(files, list_values=False) - - sections = ['client', 'mysqld'] - key_transformations = { - 'mysqld': { - 'socket': 'default_socket', - 'port': 'default_port', - 'user': 'default_user', - }, - } - - if self.login_path and self.login_path != 'client': - sections.append(self.login_path) - - if self.defaults_suffix: - sections.extend([sect + self.defaults_suffix for sect in sections]) - - configuration = defaultdict(lambda: None) - for key in keys: - for section in cnf: - if ( - section not in sections or - key not in cnf[section] - ): - continue - new_key = key_transformations.get(section, {}).get(key) or key - configuration[new_key] = strip_matching_quotes( - cnf[section][key]) - - return configuration - - - def merge_ssl_with_cnf(self, ssl, cnf): - """Merge SSL configuration dict with cnf dict""" - - merged = {} - merged.update(ssl) - prefix = 'ssl-' - for k, v in cnf.items(): - # skip unrelated options - if not k.startswith(prefix): - continue - if v is None: - continue - # special case because PyMySQL argument is significantly different - # from commandline - if k == 'ssl-verify-server-cert': - merged['check_hostname'] = v - else: - # use argument name just strip "ssl-" prefix - arg = k[len(prefix):] - merged[arg] = v - - return merged - - def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl='', - ssh_user='', ssh_host='', ssh_port='', - ssh_password='', ssh_key_filename='', init_command='', password_file=''): - - cnf = {'database': None, - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'socket': None, - 'default_socket': None, - 'default-character-set': None, - 'local-infile': None, - 'loose-local-infile': None, - 'ssl-ca': None, - 'ssl-cert': None, - 'ssl-key': None, - 'ssl-cipher': None, - 'ssl-verify-serer-cert': None, + root_logger.debug("Initializing mycli logging.") + root_logger.debug("Log file %r.", log_file) + + def connect( + self, + database: str | None = "", + user: str | None = "", + passwd: str | int | None = None, + host: str | None = "", + port: str | int | None = "", + socket: str | None = "", + character_set: str | None = "", + local_infile: bool | None = False, + ssl: dict[str, Any] | None = None, + ssh_user: str | None = "", + ssh_host: str | None = "", + ssh_port: int = 22, + ssh_password: str | None = "", + ssh_key_filename: str | None = "", + init_command: str | None = "", + unbuffered: bool | None = None, + use_keyring: bool | None = None, + reset_keyring: bool | None = None, + keepalive_ticks: int | None = None, + ) -> None: + cnf = { + "database": None, + "user": None, + "password": None, + "host": None, + "port": None, + "socket": None, + "default_socket": None, + "default-character-set": None, + "local-infile": None, + "loose-local-infile": None, + "ssl-ca": None, + "ssl-cert": None, + "ssl-key": None, + "ssl-cipher": None, + "ssl-verify-server-cert": None, } - cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) + cnf = self.read_my_cnf(self.my_cnf, list(cnf.keys())) # Fall back to config values only if user did not specify a value. - database = database or cnf['database'] - user = user or cnf['user'] or os.getenv('USER') - host = host or cnf['host'] - port = port or cnf['port'] - ssl = ssl or {} - - port = port and int(port) - if not port: - port = 3306 - if not host or host == 'localhost': + database = database or cnf["database"] + user = user or cnf["user"] or os.getenv("USER") + host = host or cnf["host"] + port = port or cnf["port"] + ssl_config: dict[str, Any] = ssl or {} + user_connection_config = self.config_without_package_defaults.get('connection', {}) + self.keepalive_ticks = keepalive_ticks + + int_port = port and int(port) + if not int_port: + int_port = DEFAULT_PORT + if not host or host == DEFAULT_HOST: socket = ( - socket or - cnf['socket'] or - cnf['default_socket'] or - guess_socket_location() + socket + or user_connection_config.get("default_socket") + or cnf["socket"] + or cnf["default_socket"] + or guess_socket_location() ) - - passwd = passwd if isinstance(passwd, str) else cnf['password'] - charset = charset or cnf['default-character-set'] or 'utf8' + passwd = passwd if isinstance(passwd, (str, int)) else cnf["password"] + + # default_character_set doesn't check in self.config_without_package_defaults, because the + # option already existed before the my.cnf deprecation. For the same reason, + # default_character_set can be in [connection] or [main]. + if not character_set: + if 'default_character_set' in self.config['connection']: + character_set = self.config['connection']['default_character_set'] + elif 'default_character_set' in self.config['main']: + character_set = self.config['main']['default_character_set'] + elif 'default_character_set' in cnf: + character_set = cnf['default_character_set'] + elif 'default-character-set' in cnf: + character_set = cnf['default-character-set'] + if not character_set: + character_set = DEFAULT_CHARSET # Favor whichever local_infile option is set. - for local_infile_option in (local_infile, cnf['local-infile'], - cnf['loose-local-infile'], False): + use_local_infile = False + for local_infile_option in ( + local_infile, + user_connection_config.get('default_local_infile'), + cnf['local_infile'], + cnf['local-infile'], + cnf['loose_local_infile'], + cnf['loose-local-infile'], + False, + ): try: - local_infile = str_to_bool(local_infile_option) + use_local_infile = str_to_bool(local_infile_option or '') break except (TypeError, ValueError): pass - ssl = self.merge_ssl_with_cnf(ssl, cnf) - # prune lone check_hostname=False - if not any(v for v in ssl.values()): - ssl = None - - # if the passwd is not specified try to set it using the password_file option - password_from_file = self.get_password_from_file(password_file) - passwd = passwd or password_from_file + # temporary my.cnf override mappings + if 'default_ssl_ca' in user_connection_config: + cnf['ssl-ca'] = user_connection_config.get('default_ssl_ca') or None + if 'default_ssl_cert' in user_connection_config: + cnf['ssl-cert'] = user_connection_config.get('default_ssl_cert') or None + if 'default_ssl_key' in user_connection_config: + cnf['ssl-key'] = user_connection_config.get('default_ssl_key') or None + if 'default_ssl_cipher' in user_connection_config: + cnf['ssl-cipher'] = user_connection_config.get('default_ssl_cipher') or None + if 'default_ssl_verify_server_cert' in user_connection_config: + cnf['ssl-verify-server-cert'] = user_connection_config.get('default_ssl_verify_server_cert') or None + + # todo: rewrite the merge method using self.config['connection'] instead of cnf, after removing my.cnf support + ssl_config_or_none: dict[str, Any] | None = self.merge_ssl_with_cnf(ssl_config, cnf) + + # default_ssl_ca_path is not represented in my.cnf + if 'default_ssl_ca_path' in self.config['connection'] and (not ssl_config_or_none or not ssl_config_or_none.get('capath')): + if ssl_config_or_none is None: + ssl_config_or_none = {} + ssl_config_or_none['capath'] = self.config['connection']['default_ssl_ca_path'] or False - # Connect to the database. + # prune lone check_hostname=False + if not any(v for v in ssl_config.values()): + ssl_config_or_none = None + + # password hierarchy + # 1. -p / --pass/--password CLI options + # 2. --password-file CLI option + # 3. envvar (MYSQL_PWD) + # 4. DSN (mysql://user:password) + # 5. cnf (.my.cnf / etc) + # 6. keyring + + keyring_identifier = f'{user}@{host}:{"" if socket else int_port}:{socket or ""}' + keyring_domain = 'mycli.net' + keyring_retrieved_cleanly = False + + if passwd is None and use_keyring and not reset_keyring: + passwd = keyring.get_password(keyring_domain, keyring_identifier) + if passwd is not None: + keyring_retrieved_cleanly = True + + # prompt for password if requested by user + if passwd == EMPTY_PASSWORD_FLAG_SENTINEL: + passwd = click.prompt(f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True) + keyring_retrieved_cleanly = False + + # should not fail, but will help the typechecker + assert not isinstance(passwd, int) + + connection_info: dict[Any, Any] = { + "database": database, + "user": user, + "password": passwd, + "host": host, + "port": int_port, + "socket": socket, + "character_set": character_set, + "local_infile": use_local_infile, + "ssl": ssl_config_or_none, + "ssh_user": ssh_user, + "ssh_host": ssh_host, + "ssh_port": int(ssh_port) if ssh_port else None, + "ssh_password": ssh_password, + "ssh_key_filename": ssh_key_filename, + "init_command": init_command, + "unbuffered": unbuffered, + } - def _connect(): + def _update_keyring(password: str | None, keyring_retrieved_cleanly: bool): + if not password: + return + if reset_keyring or (use_keyring and not keyring_retrieved_cleanly): + try: + saved_pw = keyring.get_password(keyring_domain, keyring_identifier) + if password != saved_pw or reset_keyring: + keyring.set_password(keyring_domain, keyring_identifier, password) + click.secho(f'Password saved to the system keyring at {keyring_domain}/{keyring_identifier}', err=True) + except Exception as e: + click.secho(f'Password not saved to the system keyring: {e}', err=True, fg='red') + + def _connect( + retry_ssl: bool = False, + retry_password: bool = False, + keyring_save_eligible: bool = True, + keyring_retrieved_cleanly: bool = False, + ) -> None: try: - self.sqlexecute = SQLExecute( - database, user, passwd, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, - ssh_password, ssh_key_filename, init_command - ) - except OperationalError as e: - if e.args[0] == ERROR_CODE_ACCESS_DENIED: - if password_from_file: - new_passwd = password_from_file - else: - new_passwd = click.prompt('Password', hide_input=True, - show_default=False, type=str, err=True) - self.sqlexecute = SQLExecute( - database, user, new_passwd, host, port, socket, - charset, local_infile, ssl, ssh_user, ssh_host, - ssh_port, ssh_password, ssh_key_filename, init_command + if keyring_save_eligible: + _update_keyring(connection_info["password"], keyring_retrieved_cleanly=keyring_retrieved_cleanly) + self.sqlexecute = SQLExecute(**connection_info) + except pymysql.OperationalError as e1: + if e1.args[0] == HANDSHAKE_ERROR and ssl is not None and ssl.get("mode", None) == "auto": + # if we already tried and failed to connect without SSL, raise the error + if retry_ssl: + raise e1 + # disable SSL and try to connect again + connection_info["ssl"] = None + _connect( + retry_ssl=True, keyring_retrieved_cleanly=keyring_retrieved_cleanly, keyring_save_eligible=keyring_save_eligible + ) + elif e1.args[0] == ACCESS_DENIED_ERROR and connection_info["password"] is None: + # if we already tried and failed to connect with a new password, raise the error + if retry_password: + raise e1 + # ask the user for a new password and try to connect again + new_password = click.prompt( + f"Enter password for {user}", hide_input=True, show_default=False, default='', type=str, err=True + ) + connection_info["password"] = new_password + keyring_retrieved_cleanly = False + _connect( + retry_password=True, + keyring_retrieved_cleanly=keyring_retrieved_cleanly, + keyring_save_eligible=keyring_save_eligible, ) + elif e1.args[0] == ER_MUST_CHANGE_PASSWORD_LOGIN: + self.echo( + "Your password has expired and the server rejected the connection.", + err=True, + fg='red', + ) + raise e1 + elif e1.args[0] == CR_SERVER_LOST: + self.echo( + ( + "Connection to server lost. If this error persists, it may be a mismatch between the server and " + "client SSL configuration. To troubleshoot the issue, try --ssl-mode=off or --ssl-mode=on." + ), + err=True, + fg='red', + ) + raise e1 else: - raise e + raise e1 try: if not WIN and socket: - socket_owner = getpwuid(os.stat(socket).st_uid).pw_name - self.echo( - f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) try: - _connect() - except OperationalError as e: + socket_owner = getpwuid(os.stat(socket).st_uid).pw_name + except KeyError: + socket_owner = '' + self.echo(f"Connecting to socket {socket}, owned by user {socket_owner}", err=True) + try: + _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) + except pymysql.OperationalError as e: # These are "Can't open socket" and 2x "Can't connect" if [code for code in (2001, 2002, 2003) if code == e.args[0]]: - self.logger.debug('Database connection failed: %r.', e) - self.logger.error( - "traceback: %r", traceback.format_exc()) - self.logger.debug('Retrying over TCP/IP') - self.echo( - "Failed to connect to local MySQL server through socket '{}':".format(socket)) + self.logger.debug("Database connection failed: %r.", e) + self.logger.error("traceback: %r", traceback.format_exc()) + self.logger.debug("Retrying over TCP/IP") + self.echo(f"Failed to connect to local MySQL server through socket '{socket}':") self.echo(str(e), err=True) - self.echo( - 'Retrying over TCP/IP', err=True) + self.echo("Retrying over TCP/IP", err=True) # Else fall back to TCP/IP localhost socket = "" - host = 'localhost' - port = 3306 - _connect() + host = DEFAULT_HOST + port = DEFAULT_PORT + # todo should reload the keyring identifier here instead of invalidating + _connect(keyring_save_eligible=False) else: raise e else: - host = host or 'localhost' - port = port or 3306 + host = host or DEFAULT_HOST + port = port or DEFAULT_PORT + # could try loading the keyring again here instead of assuming nothing important changed # Bad ports give particularly daft error messages try: port = int(port) - except ValueError as e: - self.echo("Error: Invalid port number: '{0}'.".format(port), - err=True, fg='red') - exit(1) + except ValueError: + self.echo(f"Error: Invalid port number: '{port}'.", err=True, fg="red") + sys.exit(1) - _connect() + _connect(keyring_retrieved_cleanly=keyring_retrieved_cleanly) + + # Check if SQLExecute detected sandbox mode during connection + if self.sqlexecute and self.sqlexecute.sandbox_mode: + self.sandbox_mode = True + self.echo( + "Your password has expired. Use ALTER USER or SET PASSWSORD to set a new password, or quit.", + err=True, + fg='yellow', + ) except Exception as e: # Connecting to a database could fail. - self.logger.debug('Database connection failed: %r.', e) + self.logger.debug("Database connection failed: %r.", e) self.logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') - exit(1) - - def get_password_from_file(self, password_file): - password_from_file = None - if password_file: - if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ - and os.access(password_file, os.R_OK): - with open(password_file) as fp: - password_from_file = fp.readline() - password_from_file = password_from_file.rstrip().lstrip() - - return password_from_file - - def handle_editor_command(self, text): - r"""Editor command is any query that is prefixed or suffixed by a '\e'. - The reason for a while loop is because a user might edit a query - multiple times. For eg: - - "select * from \e" to edit it in vim, then come - back to the prompt with the edited query "select * from - blah where q = 'abc'\e" to edit it again. - :param text: Document - :return: Document - - """ - - while special.editor_command(text): - filename = special.get_filename(text) - query = (special.get_editor_query(text) or - self.get_last_query()) - sql, message = special.open_external_editor(filename, sql=query) - if message: - # Something went wrong. Raise an exception and bail. - raise RuntimeError(message) - while True: - try: - text = self.prompt_app.prompt(default=sql) - break - except KeyboardInterrupt: - sql = "" - - continue - return text - - def handle_clip_command(self, text): - r"""A clip command is any query that is prefixed or suffixed by a - '\clip'. + self.echo(str(e), err=True, fg="red") + sys.exit(1) - :param text: Document - :return: Boolean + def run_cli(self) -> None: + main_repl(self) + def reconnect(self, database: str = "") -> bool: """ + Attempt to reconnect to the server. Return True if successful, + False if unsuccessful. - if special.clip_command(text): - query = (special.get_clip_query(text) or - self.get_last_query()) - message = special.copy_query_to_clipboard(sql=query) - if message: - raise RuntimeError(message) - return True - return False + The "database" argument is used only to improve messages. + """ + assert self.sqlexecute is not None + assert self.sqlexecute.conn is not None - def handle_prettify_binding(self, text): + # First pass with ping(reconnect=False) and minimal feedback levels. This definitely + # works as expected, and is a good idea especially when "connect" was used as a + # synonym for "use". try: - statements = sqlglot.parse(text, read='mysql') - except Exception as e: - statements = [] - if len(statements) == 1 and statements[0]: - pretty_text = statements[0].sql(pretty=True, pad=4, dialect='mysql') - else: - pretty_text = '' - self.toolbar_error_message = 'Prettify failed to parse statement' - if len(pretty_text) > 0: - pretty_text = pretty_text + ';' - return pretty_text + self.sqlexecute.conn.ping(reconnect=False) + if not database: + self.echo("Already connected.", fg="yellow") + return True + except pymysql.err.Error: + pass - def handle_unprettify_binding(self, text): + # Second pass with ping(reconnect=True). It is not demonstrated that this pass ever + # gives the benefit it is looking for, _ie_ preserves session state. We need to test + # this with connection pooling. try: - statements = sqlglot.parse(text, read='mysql') - except Exception as e: - statements = [] - if len(statements) == 1 and statements[0]: - unpretty_text = statements[0].sql(pretty=False, dialect='mysql') - else: - unpretty_text = '' - self.toolbar_error_message = 'Unprettify failed to parse statement' - if len(unpretty_text) > 0: - unpretty_text = unpretty_text + ';' - return unpretty_text - - def run_cli(self): - iterations = 0 - sqlexecute = self.sqlexecute - logger = self.logger - self.configure_pager() - - if self.smart_completion: - self.refresh_completions() - - history_file = os.path.expanduser( - os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) - if dir_path_exists(history_file): - history = FileHistoryWithTimestamp(history_file) - else: - history = None - self.echo( - 'Error: Unable to open the history file "{}". ' - 'Your query history will not be saved.'.format(history_file), - err=True, fg='red') - - key_bindings = mycli_bindings(self) - - if not self.less_chatty: - print(sqlexecute.server_info) - print('mycli', __version__) - print(SUPPORT_INFO) - print('Thanks to the contributor -', thanks_picker()) - - def get_message(): - prompt = self.get_prompt(self.prompt_format) - if self.prompt_format == self.default_prompt and len(prompt) > self.max_len_prompt: - prompt = self.get_prompt(self.default_prompt_splitln) - prompt = prompt.replace("\\x1b", "\x1b") - return ANSI(prompt) - - def get_continuation(width, *_): - if self.multiline_continuation_char == '': - continuation = '' - elif self.multiline_continuation_char: - left_padding = width - len(self.multiline_continuation_char) - continuation = " " * \ - max((left_padding - 1), 0) + \ - self.multiline_continuation_char + " " - else: - continuation = " " - return [('class:continuation', continuation)] - - def show_suggestion_tip(): - return iterations < 2 - - def one_iteration(text=None): - if text is None: - try: - text = self.prompt_app.prompt() - except KeyboardInterrupt: - return - - special.set_expanded_output(False) - - try: - text = self.handle_editor_command(text) - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') - return - - try: - if self.handle_clip_command(text): - return - except RuntimeError as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') - return - - if not text.strip(): - return - - if self.destructive_warning: - destroy = confirm_destructive_query(text) - if destroy is None: - pass # Query was not destructive. Nothing to do here. - elif destroy is True: - self.echo('Your call!') - else: - self.echo('Wise choice!') - return - else: - destroy = True - - # Keep track of whether or not the query is mutating. In case - # of a multi-statement query, the overall query is considered - # mutating if any one of the component statements is mutating - mutating = False - - try: - logger.debug('sql: %r', text) - - special.write_tee(self.get_prompt(self.prompt_format) + text) - if self.logfile: - self.logfile.write('\n# %s\n' % datetime.now()) - self.logfile.write(text) - self.logfile.write('\n') - - successful = False - start = time() - res = sqlexecute.run(text) - self.formatter.query = text - successful = True - result_count = 0 - for title, cur, headers, status in res: - logger.debug("headers: %r", headers) - logger.debug("rows: %r", cur) - logger.debug("status: %r", status) - threshold = 1000 - if (is_select(status) and - cur and cur.rowcount > threshold): - self.echo('The result set has more than {} rows.'.format( - threshold), fg='red') - if not confirm('Do you want to continue?'): - self.echo("Aborted!", err=True, fg='red') - break - - if self.auto_vertical_output: - max_width = self.prompt_app.output.get_size().columns - else: - max_width = None - - formatted = self.format_output( - title, cur, headers, special.is_expanded_output(), - max_width) - - t = time() - start - try: - if result_count > 0: - self.echo('') - try: - self.output(formatted, status) - except KeyboardInterrupt: - pass - if self.beep_after_seconds > 0 and t >= self.beep_after_seconds: - self.bell() - if special.is_timing_enabled(): - self.echo('Time: %0.03fs' % t) - except KeyboardInterrupt: - pass - - start = time() - result_count += 1 - mutating = mutating or destroy or is_mutating(status) - special.unset_once_if_written() - special.unset_pipe_once_if_written() - except EOFError as e: - raise e - except KeyboardInterrupt: - # get last connection id - connection_id_to_kill = sqlexecute.connection_id - # some mysql compatible databases may not implemente connection_id() - if connection_id_to_kill > 0: - logger.debug("connection id to kill: %r", connection_id_to_kill) - # Restart connection to the database - sqlexecute.connect() - try: - for title, cur, headers, status in sqlexecute.run('kill %s' % connection_id_to_kill): - status_str = str(status).lower() - if status_str.find('ok') > -1: - logger.debug("cancelled query, connection id: %r, sql: %r", - connection_id_to_kill, text) - self.echo("cancelled query", err=True, fg='red') - except Exception as e: - self.echo('Encountered error while cancelling query: {}'.format(e), - err=True, fg='red') - else: - logger.debug("Did not get a connection id, skip cancelling query") - except NotImplementedError: - self.echo('Not Yet Implemented.', fg="yellow") - except OperationalError as e: - logger.debug("Exception: %r", e) - if (e.args[0] in (2003, 2006, 2013)): - logger.debug('Attempting to reconnect.') - self.echo('Reconnecting...', fg='yellow') - try: - sqlexecute.connect() - logger.debug('Reconnected successfully.') - one_iteration(text) - return # OK to just return, cuz the recursion call runs to the end. - except OperationalError as e: - logger.debug('Reconnect failed. e: %r', e) - self.echo(str(e), err=True, fg='red') - # If reconnection failed, don't proceed further. - return - else: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') - except Exception as e: - logger.error("sql: %r, error: %r", text, e) - logger.error("traceback: %r", traceback.format_exc()) - self.echo(str(e), err=True, fg='red') - else: - if is_dropping_database(text, self.sqlexecute.dbname): - self.sqlexecute.dbname = None - self.sqlexecute.connect() - - # Refresh the table names and column names if necessary. - if need_completion_refresh(text): - self.refresh_completions( - reset=need_completion_reset(text)) - finally: - if self.logfile is False: - self.echo("Warning: This query was not logged.", - err=True, fg='red') - query = Query(text, successful, mutating) - self.query_history.append(query) - - get_toolbar_tokens = create_toolbar_tokens_func( - self, show_suggestion_tip) - if self.wider_completion_menu: - complete_style = CompleteStyle.MULTI_COLUMN - else: - complete_style = CompleteStyle.COLUMN - - with self._completer_lock: - - if self.key_bindings == 'vi': - editing_mode = EditingMode.VI - else: - editing_mode = EditingMode.EMACS - - self.prompt_app = PromptSession( - lexer=PygmentsLexer(MyCliLexer), - reserve_space_for_menu=self.get_reserved_space(), - message=get_message, - prompt_continuation=get_continuation, - bottom_toolbar=get_toolbar_tokens, - complete_style=complete_style, - input_processors=[ConditionalProcessor( - processor=HighlightMatchingBracketProcessor( - chars='[](){}'), - filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() - )], - tempfile_suffix='.sql', - completer=DynamicCompleter(lambda: self.completer), - history=history, - auto_suggest=AutoSuggestFromHistory(), - complete_while_typing=True, - multiline=cli_is_multiline(self), - style=style_factory(self.syntax_style, self.cli_style), - include_default_pygments_style=False, - key_bindings=key_bindings, - enable_open_in_editor=True, - enable_system_prompt=True, - enable_suspend=True, - editing_mode=editing_mode, - search_ignore_case=True - ) + old_connection_id = self.sqlexecute.connection_id + self.logger.debug("Attempting to reconnect.") + self.echo("Reconnecting...", fg="yellow") + self.sqlexecute.conn.ping(reconnect=True) + # if a database is currently selected, set it on the conn again + if self.sqlexecute.dbname: + self.sqlexecute.conn.select_db(self.sqlexecute.dbname) + self.logger.debug("Reconnected successfully.") + self.echo("Reconnected successfully.", fg="yellow") + self.sqlexecute.reset_connection_id() + if old_connection_id != self.sqlexecute.connection_id: + self.echo("Any session state was reset.", fg="red") + return True + except pymysql.err.Error: + pass + # Third pass with sqlexecute.connect() should always work, but always resets session state. try: - while True: - one_iteration() - iterations += 1 - except EOFError: - special.close_tee() - if not self.less_chatty: - self.echo('Goodbye!') - - def log_output(self, output): - """Log the output in the audit log, if it's enabled.""" - if self.logfile: - click.echo(output, file=self.logfile) - - def echo(self, s, **kwargs): - """Print a message to stdout. - - The message will be logged in the audit log, if enabled. - - All keyword arguments are passed to click.echo(). - - """ - self.log_output(s) - click.secho(s, **kwargs) - - def bell(self): - """Print a bell on the stderr. - """ - click.secho('\a', err=True, nl=False) - - def get_output_margin(self, status=None): - """Get the output margin (number of rows for the prompt, footer and - timing message.""" - margin = self.get_reserved_space() + self.get_prompt(self.prompt_format).count('\n') + 1 - if special.is_timing_enabled(): - margin += 1 - if status: - margin += 1 + status.count('\n') - - return margin - - - def output(self, output, status=None): - """Output text to stdout or a pager command. - - The status text is not outputted to pager or files. - - The message will be logged in the audit log, if enabled. The - message will be written to the tee file, if enabled. The - message will be written to the output file, if enabled. - - """ - if output: - size = self.prompt_app.output.get_size() - - margin = self.get_output_margin(status) - - fits = True - buf = [] - output_via_pager = self.explicit_pager and special.is_pager_enabled() - for i, line in enumerate(output, 1): - self.log_output(line) - special.write_tee(line) - special.write_once(line) - special.write_pipe_once(line) - - if fits or output_via_pager: - # buffering - buf.append(line) - if len(line) > size.columns or i > (size.rows - margin): - fits = False - if not self.explicit_pager and special.is_pager_enabled(): - # doesn't fit, use pager - output_via_pager = True - - if not output_via_pager: - # doesn't fit, flush buffer - for buf_line in buf: - click.secho(buf_line) - buf = [] - else: - click.secho(line) - - if buf: - if output_via_pager: - def newlinewrapper(text): - for line in text: - yield line + "\n" - click.echo_via_pager(newlinewrapper(buf)) - else: - for line in buf: - click.secho(line) - - if status: - self.log_output(status) - click.secho(status) - - def configure_pager(self): - # Provide sane defaults for less if they are empty. - if not os.environ.get('LESS'): - os.environ['LESS'] = '-RXF' - - cnf = self.read_my_cnf_files(self.cnf_files, ['pager', 'skip-pager']) - cnf_pager = cnf['pager'] or self.config['main']['pager'] - if cnf_pager: - special.set_pager(cnf_pager) - self.explicit_pager = True - else: - self.explicit_pager = False + self.logger.debug("Creating new connection") + self.echo("Creating new connection...", fg="yellow") + self.sqlexecute.connect() + self.logger.debug("New connection created successfully.") + self.echo("New connection created successfully.", fg="yellow") + self.echo("Any session state was reset.", fg="red") + return True + except pymysql.OperationalError as e: + self.logger.debug("Reconnect failed. e: %r", e) + self.echo(str(e), err=True, fg="red") + return False - if cnf['skip-pager'] or not self.config['main'].as_bool('enable_pager'): - special.disable_pager() + def refresh_completions(self, reset: bool = False) -> list[SQLResult]: + # Cancel any in-flight schema prefetch before the completer is + # replaced. Loaded-schema bookkeeping is intentionally preserved + # so switching between already-loaded schemas does not re-fetch. + self.schema_prefetcher.stop() - def refresh_completions(self, reset=False): + assert self.sqlexecute is not None if reset: + # Update the active completer's current-schema pointer right + # away so unqualified completions reflect a schema switch + # even before the background refresh finishes. with self._completer_lock: - self.completer.reset_completions() + self.completer.set_dbname(self.sqlexecute.dbname) self.completion_refresher.refresh( - self.sqlexecute, self._on_completions_refreshed, - {'smart_completion': self.smart_completion, - 'supported_formats': self.formatter.supported_formats, - 'keyword_casing': self.completer.keyword_casing}) + self.sqlexecute, + self._on_completions_refreshed, + { + "smart_completion": self.smart_completion, + "supported_formats": self.main_formatter.supported_formats, + "keyword_casing": self.completer.keyword_casing, + }, + ) - return [(None, None, None, - 'Auto-completion refresh started in the background.')] + return [SQLResult(status="Auto-completion refresh started in the background.")] - def _on_completions_refreshed(self, new_completer): - """Swap the completer object in cli with the newly created completer. - """ + def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: + """Swap the completer object in cli with the newly created completer.""" with self._completer_lock: + new_completer.copy_other_schemas_from(self.completer, exclude=new_completer.dbname) self.completer = new_completer - if self.prompt_app: + if self.prompt_session: # After refreshing, redraw the CLI to clear the statusbar # "Refreshing completions..." indicator - self.prompt_app.app.invalidate() - - def get_completions(self, text, cursor_positition): - with self._completer_lock: - return self.completer.get_completions( - Document(text=text, cursor_position=cursor_positition), None) - - def get_prompt(self, string): - sqlexecute = self.sqlexecute - host = self.login_path if self.login_path and self.login_path_as_host else sqlexecute.host - now = datetime.now() - string = string.replace('\\u', sqlexecute.user or '(none)') - string = string.replace('\\h', host or '(none)') - string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_info.species.name) - string = string.replace('\\n', "\n") - string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) - string = string.replace('\\m', now.strftime('%M')) - string = string.replace('\\P', now.strftime('%p')) - string = string.replace('\\R', now.strftime('%H')) - string = string.replace('\\r', now.strftime('%I')) - string = string.replace('\\s', now.strftime('%S')) - string = string.replace('\\p', str(sqlexecute.port)) - string = string.replace('\\A', self.dsn_alias or '(none)') - string = string.replace('\\_', ' ') - return string - - def run_query(self, query, new_line=True): + self.prompt_session.app.invalidate() + + # Kick off background prefetch for any extra schemas configured + # via ``prefetch_schemas_mode`` so users get cross-schema completions. + self.schema_prefetcher.start_configured() + + def run_query( + self, + query: str, + checkpoint: TextIOWrapper | None = None, + new_line: bool = True, + ) -> None: """Runs *query*.""" + assert self.sqlexecute is not None + self.log_query(query) results = self.sqlexecute.run(query) for result in results: - title, cur, headers, status = result - self.formatter.query = query - output = self.format_output(title, cur, headers, special.is_expanded_output()) + self.main_formatter.query = query + self.redirect_formatter.query = query + output = self.format_sqlresult( + result, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=self.null_string, + numeric_alignment=self.numeric_alignment, + binary_display=self.binary_display, + ) for line in output: + self.log_output(line) click.echo(line, nl=new_line) - def format_output(self, title, cur, headers, expanded=False, - max_width=None): - expanded = expanded or self.formatter.format_name == 'vertical' - output = [] - - output_kwargs = { - 'dialect': 'unix', - 'disable_numparse': True, - 'preserve_whitespace': True, - 'style': self.output_style - } - - if not self.formatter.format_name in sql_format.supported_formats: - output_kwargs["preprocessors"] = (preprocessors.align_decimals, ) - - if title: # Only print the title if it's not None. - output = itertools.chain(output, [title]) - - if cur: - column_types = None - if hasattr(cur, 'description'): - def get_col_type(col): - col_type = FIELD_TYPES.get(col[1], str) - return col_type if type(col_type) is type else str - column_types = [get_col_type(col) for col in cur.description] - - if max_width is not None: - cur = list(cur) - - formatted = self.formatter.format_output( - cur, headers, format_name='vertical' if expanded else None, - column_types=column_types, - **output_kwargs) - - if isinstance(formatted, str): - formatted = formatted.splitlines() - formatted = iter(formatted) - - if (not expanded and max_width and headers and cur): - first_line = next(formatted) - if len(strip_ansi(first_line)) > max_width: - formatted = self.formatter.format_output( - cur, headers, format_name='vertical', column_types=column_types, **output_kwargs) - if isinstance(formatted, str): - formatted = iter(formatted.splitlines()) - else: - formatted = itertools.chain([first_line], formatted) - - output = itertools.chain(output, formatted) - - - return output - - def get_reserved_space(self): - """Get the number of lines to reserve for the completion menu.""" - reserved_space_ratio = .45 - max_reserved_space = 8 - _, height = shutil.get_terminal_size() - return min(int(round(height * reserved_space_ratio)), max_reserved_space) + # get and display warnings if enabled + if special.is_show_warnings_enabled() and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: + warnings = self.sqlexecute.run("SHOW WARNINGS") + for warning in warnings: + output = self.format_sqlresult( + warning, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=self.null_string, + numeric_alignment=self.numeric_alignment, + binary_display=self.binary_display, + is_warnings_style=True, + ) + for line in output: + click.echo(line, nl=new_line) + if checkpoint: + checkpoint.write(query.rstrip('\n') + '\n') + checkpoint.flush() - def get_last_query(self): + def get_last_query(self) -> str | None: """Get the last query executed or None.""" return self.query_history[-1][0] if self.query_history else None @click.command() -@click.option('-h', '--host', envvar='MYSQL_HOST', help='Host address of the database.') -@click.option('-P', '--port', envvar='MYSQL_TCP_PORT', type=int, help='Port number to use for connection. Honors ' - '$MYSQL_TCP_PORT.') -@click.option('-u', '--user', help='User name to connect to the database.') -@click.option('-S', '--socket', envvar='MYSQL_UNIX_PORT', help='The socket file to use for connection.') -@click.option('-p', '--password', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database.') -@click.option('--pass', 'password', envvar='MYSQL_PWD', type=str, - help='Password to connect to the database.') -@click.option('--ssh-user', help='User name to connect to ssh server.') -@click.option('--ssh-host', help='Host name to connect to ssh server.') -@click.option('--ssh-port', default=22, help='Port to connect to ssh server.') -@click.option('--ssh-password', help='Password to connect to ssh server.') -@click.option('--ssh-key-filename', help='Private key filename (identify file) for the ssh connection.') -@click.option('--ssh-config-path', help='Path to ssh configuration.', - default=os.path.expanduser('~') + '/.ssh/config') -@click.option('--ssh-config-host', help='Host to connect to ssh server reading from ssh configuration.') -@click.option('--ssl', 'ssl_enable', is_flag=True, - help='Enable SSL for connection (automatically enabled with other flags).') -@click.option('--ssl-ca', help='CA file in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-capath', help='CA directory.') -@click.option('--ssl-cert', help='X509 cert in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-key', help='X509 key in PEM format.', - type=click.Path(exists=True)) -@click.option('--ssl-cipher', help='SSL cipher to use.') -@click.option('--tls-version', - type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), - help='TLS protocol version for secure connection.') -@click.option('--ssl-verify-server-cert', is_flag=True, - help=('Verify server\'s "Common Name" in its cert against ' - 'hostname used when connecting. This option is disabled ' - 'by default.')) -# as of 2016-02-15 revocation list is not supported by underling PyMySQL -# library (--ssl-crl and --ssl-crlpath options in vanilla mysql client) -@click.option('-V', '--version', is_flag=True, help='Output mycli\'s version.') -@click.option('-v', '--verbose', is_flag=True, help='Verbose output.') -@click.option('-D', '--database', 'dbname', help='Database to use.') -@click.option('-d', '--dsn', default='', envvar='DSN', - help='Use DSN configured into the [alias_dsn] section of myclirc file.') -@click.option('--list-dsn', 'list_dsn', is_flag=True, - help='list of DSN configured into the [alias_dsn] section of myclirc file.') -@click.option('--list-ssh-config', 'list_ssh_config', is_flag=True, - help='list ssh configurations in the ssh config (requires paramiko).') -@click.option('-R', '--prompt', 'prompt', - help='Prompt format (Default: "{0}").'.format( - MyCli.default_prompt)) -@click.option('-l', '--logfile', type=click.File(mode='a', encoding='utf-8'), - help='Log every query and its results to a file.') -@click.option('--defaults-group-suffix', type=str, - help='Read MySQL config groups with the specified suffix.') -@click.option('--defaults-file', type=click.Path(), - help='Only read MySQL options from the given file.') -@click.option('--myclirc', type=click.Path(), default="~/.myclirc", - help='Location of myclirc file.') -@click.option('--auto-vertical-output', is_flag=True, - help='Automatically switch to vertical output mode if the result is wider than the terminal width.') -@click.option('-t', '--table', is_flag=True, - help='Display batch output in table format.') -@click.option('--csv', is_flag=True, - help='Display batch output in CSV format.') -@click.option('--warn/--no-warn', default=None, - help='Warn before running a destructive query.') -@click.option('--local-infile', type=bool, - help='Enable/disable LOAD DATA LOCAL INFILE.') -@click.option('-g', '--login-path', type=str, - help='Read this path from the login file.') -@click.option('-e', '--execute', type=str, - help='Execute command and quit.') -@click.option('--init-command', type=str, - help='SQL statement to execute after connecting.') -@click.option('--charset', type=str, - help='Character set for MySQL session.') -@click.option('--password-file', type=click.Path(), - help='File or FIFO path containing the password to connect to the db if not specified otherwise.') -@click.argument('database', default='', nargs=1) -def cli(database, user, host, port, socket, password, dbname, - version, verbose, prompt, logfile, defaults_group_suffix, - defaults_file, login_path, auto_vertical_output, local_infile, - ssl_enable, ssl_ca, ssl_capath, ssl_cert, ssl_key, ssl_cipher, - tls_version, ssl_verify_server_cert, table, csv, warn, execute, - myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, - init_command, charset, password_file): +@clickdc.adddc('cli_args', CliArgs) +@click.version_option(mycli_package.__version__, '--version', '-V', help="Output mycli's version.") +def click_entrypoint( + cli_args: CliArgs, +) -> None: """A MySQL terminal client with auto-completion and syntax highlighting. \b @@ -1205,275 +897,435 @@ def cli(database, user, host, port, socket, password, dbname, """ - if version: - print('Version:', __version__) - sys.exit(0) + cli_verbosity = preprocess_cli_args(cli_args, is_valid_connection_scheme) + + mycli = MyCli( + prompt=cli_args.prompt, + toolbar_format=cli_args.toolbar, + logfile=cli_args.logfile, + defaults_suffix=cli_args.defaults_group_suffix, + defaults_file=cli_args.defaults_file, + login_path=cli_args.login_path, + auto_vertical_output=cli_args.auto_vertical_output, + warn=cli_args.warn, + myclirc=cli_args.myclirc, + show_warnings=cli_args.show_warnings, + cli_verbosity=cli_verbosity, + ) - mycli = MyCli(prompt=prompt, logfile=logfile, - defaults_suffix=defaults_group_suffix, - defaults_file=defaults_file, login_path=login_path, - auto_vertical_output=auto_vertical_output, warn=warn, - myclirc=myclirc) - if list_dsn: - try: - alias_dsn = mycli.config['alias_dsn'] - except KeyError as err: - click.secho('Invalid DSNs found in the config file. '\ - 'Please check the "[alias_dsn]" section in myclirc.', - err=True, fg='red') - exit(1) - except Exception as e: - click.secho(str(e), err=True, fg='red') - exit(1) - for alias, value in alias_dsn.items(): - if verbose: - click.secho("{} : {}".format(alias, value)) - else: - click.secho(alias) - sys.exit(0) - if list_ssh_config: - ssh_config = read_ssh_config(ssh_config_path) - for host in ssh_config.get_hostnames(): - if verbose: - host_config = ssh_config.lookup(host) - click.secho("{} : {}".format( - host, host_config.get('hostname'))) - else: - click.secho(host) + if cli_args.checkup: + main_checkup(mycli) sys.exit(0) - # Choose which ever one has a valid value. - database = dbname or database - - ssl = { - 'enable': ssl_enable, - 'ca': ssl_ca and os.path.expanduser(ssl_ca), - 'cert': ssl_cert and os.path.expanduser(ssl_cert), - 'key': ssl_key and os.path.expanduser(ssl_key), - 'capath': ssl_capath, - 'cipher': ssl_cipher, - 'tls_version': tls_version, - 'check_hostname': ssl_verify_server_cert, - } - # remove empty ssl options - ssl = {k: v for k, v in ssl.items() if v is not None} + if cli_args.csv and cli_args.format not in [None, 'csv']: + click.secho("Conflicting --csv and --format arguments.", err=True, fg="red") + sys.exit(1) + + if cli_args.table and cli_args.format not in [None, 'table']: + click.secho("Conflicting --table and --format arguments.", err=True, fg="red") + sys.exit(1) + + if not cli_args.format: + cli_args.format = 'default' - dsn_uri = None + if cli_args.csv: + cli_args.format = 'csv' + + if cli_args.table: + cli_args.format = 'table' + + if cli_args.deprecated_ssl is not None: + click.secho( + "Warning: The --ssl/--no-ssl CLI options are deprecated and will be removed in a future release. " + "Please use the \"default_ssl_mode\" config option or --ssl-mode CLI flag instead. " + f"See issue {ISSUES_URL}/1507", + err=True, + fg="yellow", + ) + + # ssh_port and ssh_config_path have truthy defaults and are not included + if ( + any([ + cli_args.ssh_user, + cli_args.ssh_host, + cli_args.ssh_password, + cli_args.ssh_key_filename, + cli_args.list_ssh_config, + cli_args.ssh_config_host, + ]) + and not cli_args.ssh_warning_off + ): + click.secho( + f"Warning: The built-in SSH functionality is deprecated and will be removed in a future release. See issue {ISSUES_URL}/1464", + err=True, + fg="red", + ) + + if cli_args.list_dsn: + sys.exit(main_list_dsn(mycli)) - # Treat the database argument as a DSN alias if we're missing - # other connection information. - if (mycli.config['alias_dsn'] and database and '://' not in database - and not any([user, password, host, port, login_path])): - dsn, database = database, '' + if cli_args.list_ssh_config: + sys.exit(main_list_ssh_config(mycli, cli_args)) - if database and '://' in database: - dsn_uri, database = database, '' + if 'MYSQL_UNIX_PORT' in os.environ: + # deprecated 2026-03 + click.secho( + "The MYSQL_UNIX_PORT environment variable is deprecated in favor of MYSQL_UNIX_SOCKET. " + "MYSQL_UNIX_PORT will be removed in a future release.", + err=True, + fg="red", + ) + if not cli_args.socket: + cli_args.socket = os.environ['MYSQL_UNIX_PORT'] + + if 'DSN' in os.environ: + # deprecated 2026-03 + click.secho( + "The DSN environment variable is deprecated in favor of MYSQL_DSN. Support for DSN will be removed in a future release.", + err=True, + fg="red", + ) + if not cli_args.dsn: + cli_args.dsn = os.environ['DSN'] + + # Choose which ever one has a valid value. + database = cli_args.dbname or cli_args.database + + dsn_uri = None - if dsn: + # Treat the database argument as a DSN alias only if it matches a configured alias + # todo why is port tested but not socket? + truthy_password = cli_args.password not in (None, EMPTY_PASSWORD_FLAG_SENTINEL) + if ( + database + and "://" not in database + and not any([ + cli_args.user, + truthy_password, + cli_args.host, + cli_args.port, + cli_args.login_path, + ]) + and database in mycli.config.get("alias_dsn", {}) + ): + cli_args.dsn, database = database, "" + + if database and "://" in database: + dsn_uri, database = database, "" + + if cli_args.dsn: try: - dsn_uri = mycli.config['alias_dsn'][dsn] + dsn_uri = mycli.config["alias_dsn"][cli_args.dsn] except KeyError: - click.secho('Could not find the specified DSN in the config file. ' - 'Please check the "[alias_dsn]" section in your ' - 'myclirc.', err=True, fg='red') - exit(1) + is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.dsn) + if is_valid_scheme: + dsn_uri = cli_args.dsn + else: + click.secho( + "Could not find the specified DSN in the config file. Please check the \"[alias_dsn]\" section in your myclirc.", + err=True, + fg="red", + ) + sys.exit(1) else: - mycli.dsn_alias = dsn + mycli.dsn_alias = cli_args.dsn if dsn_uri: uri = urlparse(dsn_uri) if not database: database = uri.path[1:] # ignore the leading fwd slash - if not user and uri.username is not None: - user = unquote(uri.username) - if not password and uri.password is not None: - password = unquote(uri.password) - if not host: - host = uri.hostname - if not port: - port = uri.port - - if ssh_config_host: - ssh_config = read_ssh_config( - ssh_config_path - ).lookup(ssh_config_host) - ssh_host = ssh_host if ssh_host else ssh_config.get('hostname') - ssh_user = ssh_user if ssh_user else ssh_config.get('user') - if ssh_config.get('port') and ssh_port == 22: - # port has a default value, overwrite it if it's in the config - ssh_port = int(ssh_config.get('port')) - ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config.get( - 'identityfile', [None])[0] - - ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) - - mycli.connect( - database=database, - user=user, - passwd=password, - host=host, - port=port, - socket=socket, - local_infile=local_infile, - ssl=ssl, - ssh_user=ssh_user, - ssh_host=ssh_host, - ssh_port=ssh_port, - ssh_password=ssh_password, - ssh_key_filename=ssh_key_filename, - init_command=init_command, - charset=charset, - password_file=password_file - ) + if not cli_args.user and uri.username is not None: + cli_args.user = unquote(uri.username) + # todo: rationalize the behavior of empty-string passwords here + if not cli_args.password and uri.password is not None: + cli_args.password = unquote(uri.password) + if not cli_args.host: + cli_args.host = uri.hostname + if not cli_args.port: + cli_args.port = uri.port + + if uri.query: + dsn_params = parse_qs(uri.query) + else: + dsn_params = {} - mycli.logger.debug('Launch Params: \n' - '\tdatabase: %r' - '\tuser: %r' - '\thost: %r' - '\tport: %r', database, user, host, port) + if params := dsn_params.get('ssl'): + click.secho( + 'Warning: The "ssl" DSN URI parameter is deprecated and will be removed in a future release. ' + 'Please use the "ssl_mode" parameter instead. ' + f'See issue {ISSUES_URL}/1507', + err=True, + fg='yellow', + ) + if params[0].lower() == 'true': + cli_args.ssl_mode = 'on' + if params := dsn_params.get('ssl_mode'): + cli_args.ssl_mode = cli_args.ssl_mode or params[0] + if params := dsn_params.get('ssl_ca'): + cli_args.ssl_ca = cli_args.ssl_ca or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('ssl_capath'): + cli_args.ssl_capath = cli_args.ssl_capath or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('ssl_cert'): + cli_args.ssl_cert = cli_args.ssl_cert or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('ssl_key'): + cli_args.ssl_key = cli_args.ssl_key or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('ssl_cipher'): + cli_args.ssl_cipher = cli_args.ssl_cipher or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('tls_version'): + cli_args.tls_version = cli_args.tls_version or params[0] + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('ssl_verify_server_cert'): + cli_args.ssl_verify_server_cert = cli_args.ssl_verify_server_cert or (params[0].lower() == 'true') + cli_args.ssl_mode = cli_args.ssl_mode or 'on' + if params := dsn_params.get('socket'): + cli_args.socket = cli_args.socket or params[0] + if params := dsn_params.get('keepalive_ticks'): + if cli_args.keepalive_ticks is None: + cli_args.keepalive_ticks = int(params[0]) + if params := dsn_params.get('character_set'): + cli_args.character_set = cli_args.character_set or params[0] + + keepalive_ticks = cli_args.keepalive_ticks if cli_args.keepalive_ticks is not None else mycli.default_keepalive_ticks + ssl_mode = cli_args.ssl_mode or mycli.ssl_mode + + # if there is a mismatch between the ssl_mode value and other sources of ssl config, show a warning + # specifically using "is False" to not pickup the case where cli_args.deprecated_ssl is None (not set by the user) + if cli_args.deprecated_ssl and ssl_mode == "off" or cli_args.deprecated_ssl is False and ssl_mode in ("auto", "on"): + click.secho( + f"Warning: The current ssl_mode value of '{ssl_mode}' is overriding the value provided by " + f"either the --ssl/--no-ssl CLI options or a DSN URI parameter (ssl={cli_args.deprecated_ssl}).", + err=True, + fg="yellow", + ) - # --execute argument - if execute: - try: - if csv: - mycli.formatter.format_name = 'csv' - if execute.endswith(r'\G'): - execute = execute[:-2] - elif table: - if execute.endswith(r'\G'): - execute = execute[:-2] - else: - mycli.formatter.format_name = 'tsv' + # configure SSL if ssl_mode is auto/on or if + # cli_args.deprecated_ssl = True (from --ssl or a DSN URI) and ssl_mode is None + if ssl_mode in ("auto", "on") or (cli_args.deprecated_ssl and ssl_mode is None): + if cli_args.socket and ssl_mode == 'auto': + ssl = None + else: + ssl = { + "mode": ssl_mode, + "enable": cli_args.deprecated_ssl, # todo: why is this set at all? + "ca": cli_args.ssl_ca and os.path.expanduser(cli_args.ssl_ca), + "cert": cli_args.ssl_cert and os.path.expanduser(cli_args.ssl_cert), + "key": cli_args.ssl_key and os.path.expanduser(cli_args.ssl_key), + "capath": cli_args.ssl_capath, + "cipher": cli_args.ssl_cipher, + "tls_version": cli_args.tls_version, + "check_hostname": cli_args.ssl_verify_server_cert, + } + # remove empty ssl options + ssl = {k: v for k, v in ssl.items() if v is not None} + else: + ssl = None - mycli.run_query(execute) - exit(0) - except Exception as e: - click.secho(str(e), err=True, fg='red') - exit(1) + if cli_args.ssh_config_host: + ssh_config = read_ssh_config(cli_args.ssh_config_path).lookup(cli_args.ssh_config_host) + ssh_host = cli_args.ssh_host if cli_args.ssh_host else ssh_config.get("hostname") + ssh_user = cli_args.ssh_user if cli_args.ssh_user else ssh_config.get("user") + if ssh_config.get("port") and cli_args.ssh_port == 22: + # port has a default value, overwrite it if it's in the config + ssh_port = int(ssh_config.get("port")) + else: + ssh_port = cli_args.ssh_port + ssh_key_filename = cli_args.ssh_key_filename if cli_args.ssh_key_filename else ssh_config.get("identityfile", [None])[0] + else: + ssh_host = cli_args.ssh_host + ssh_user = cli_args.ssh_user + ssh_port = cli_args.ssh_port + ssh_key_filename = cli_args.ssh_key_filename - if sys.stdin.isatty(): - mycli.run_cli() + ssh_key_filename = ssh_key_filename and os.path.expanduser(ssh_key_filename) + # Merge init-commands: global, DSN-specific, then CLI + init_cmds: list[str] = [] + # 1) Global init-commands + global_section = mycli.config.get("init-commands", {}) + for _, val in global_section.items(): + if isinstance(val, (list, tuple)): + init_cmds.extend(val) + elif val: + init_cmds.append(val) + # 2) DSN-specific init-commands + if cli_args.dsn: + alias_section = mycli.config.get("alias_dsn.init-commands", {}) + if cli_args.dsn in alias_section: + val = alias_section.get(cli_args.dsn) + if isinstance(val, (list, tuple)): + init_cmds.extend(val) + elif val: + init_cmds.append(val) + # 3) CLI-provided init_command + if cli_args.init_command: + init_cmds.append(cli_args.init_command) + + combined_init_cmd = "; ".join(cmd.strip() for cmd in init_cmds if cmd) + + if cli_args.use_keyring is not None and cli_args.use_keyring.lower() == 'reset': + use_keyring = True + reset_keyring = True + elif cli_args.use_keyring is None: + use_keyring = str_to_bool(mycli.config['main'].get('use_keyring', 'False')) + reset_keyring = False else: - stdin = click.get_text_stream('stdin') - try: - stdin_text = stdin.read() - except MemoryError: - click.secho('Failed! Ran out of memory.', err=True, fg='red') - click.secho('You might want to try the official mysql client.', err=True, fg='red') - click.secho('Sorry... :(', err=True, fg='red') - exit(1) - - if mycli.destructive_warning and is_destructive(stdin_text): - try: - sys.stdin = open('/dev/tty') - warn_confirmed = confirm_destructive_query(stdin_text) - except (IOError, OSError): - mycli.logger.warning('Unable to open TTY as stdin.') - if not warn_confirmed: - exit(0) + use_keyring = str_to_bool(cli_args.use_keyring) + reset_keyring = False + + # todo: removeme after a period of transition + for tup in [ + ('client', 'prompt', 'prompt', 'main', 'prompt'), + ('client', 'pager', 'pager', 'main', 'pager'), + ('client', 'skip-pager', 'skip-pager', 'main', 'enable_pager'), + # this is a white lie, because default_character_set can actually be read from the package config + ('client', 'default-character-set', 'default-character-set', 'connection', 'default_character_set'), + # local-infile can be read from both sections + ('mysqld', 'local-infile', 'local-infile', 'connection', 'default_local_infile'), + ('client', 'local-infile', 'local-infile', 'connection', 'default_local_infile'), + ('mysqld', 'loose-local-infile', 'loose-local-infile', 'connection', 'default_local_infile'), + ('client', 'loose-local-infile', 'loose-local-infile', 'connection', 'default_local_infile'), + # todo: in the future we should add default_port, etc, but only in .myclirc + # they are currently ignored in my.cnf + ('mysqld', 'default_socket', 'socket', 'connection', 'default_socket'), + ('client', 'ssl-ca', 'ssl-ca', 'connection', 'default_ssl_ca'), + ('client', 'ssl-cert', 'ssl-cert', 'connection', 'default_ssl_cert'), + ('client', 'ssl-key', 'ssl-key', 'connection', 'default_ssl_key'), + ('client', 'ssl-cipher', 'ssl-cipher', 'connection', 'default_ssl_cipher'), + ('client', 'ssl-verify-server-cert', 'ssl-verify-server-cert', 'connection', 'default_ssl_verify_server_cert'), + ]: + ( + mycnf_section_name, + mycnf_item_name, + printable_mycnf_item_name, + myclirc_section_name, + myclirc_item_name, + ) = tup + if str_to_bool(mycli.config['main'].get('my_cnf_transition_done', 'False')): + break + if ( + mycli.my_cnf[mycnf_section_name].get(mycnf_item_name) is None + and mycli.my_cnf[mycnf_section_name].get(mycnf_item_name.replace('-', '_')) is None + ): + continue + user_section = mycli.config_without_package_defaults.get(myclirc_section_name, {}) + if user_section.get(myclirc_item_name) is None: + cnf_value = mycli.my_cnf[mycnf_section_name].get(mycnf_item_name) + if cnf_value is None: + cnf_value = mycli.my_cnf[mycnf_section_name].get(mycnf_item_name.replace('-', '_')) + click.secho( + dedent( + f""" + Reading configuration from my.cnf files is deprecated. + See {ISSUES_URL}/1490 . + The cause of this message is the following in a my.cnf file without a corresponding + ~/.myclirc entry: - try: - new_line = True + [{mycnf_section_name}] + {printable_mycnf_item_name} = {cnf_value} - if csv: - mycli.formatter.format_name = 'csv' - elif not table: - mycli.formatter.format_name = 'tsv' + To suppress this message, remove the my.cnf item add or the following to ~/.myclirc: - mycli.run_query(stdin_text, new_line=new_line) - exit(0) - except Exception as e: - click.secho(str(e), err=True, fg='red') - exit(1) + [{myclirc_section_name}] + {myclirc_item_name} = + The ~/.myclirc setting will take precedence. In the future, the my.cnf will be ignored. -def need_completion_refresh(queries): - """Determines if the completion needs a refresh by checking if the sql - statement is an alter, create, drop or change db.""" - for query in sqlparse.split(queries): - try: - first_token = query.split()[0] - if first_token.lower() in ('alter', 'create', 'use', '\\r', - '\\u', 'connect', 'drop', 'rename'): - return True - except Exception: - return False + Values are documented at {REPO_URL}/blob/main/mycli/myclirc . An + empty is generally accepted. + To ignore all of this, set -def need_completion_reset(queries): - """Determines if the statement is a database switch such as 'use' or '\\u'. - When a database is changed the existing completions must be reset before we - start the completion refresh for the new database. - """ - for query in sqlparse.split(queries): - try: - first_token = query.split()[0] - if first_token.lower() in ('use', '\\u'): - return True - except Exception: - return False + [main] + my_cnf_transition_done = True + + in ~/.myclirc. + -------- -def is_mutating(status): - """Determines if the statement is mutating based on the status.""" - if not status: - return False + """ + ), + err=True, + fg='yellow', + ) - mutating = set(['insert', 'update', 'delete', 'alter', 'create', 'drop', - 'replace', 'truncate', 'load', 'rename']) - return status.split(None, 1)[0].lower() in mutating + mycli.connect( + database=database, + user=cli_args.user, + passwd=cli_args.password, + host=cli_args.host, + port=cli_args.port, + socket=cli_args.socket, + local_infile=cli_args.local_infile, + ssl=ssl, + ssh_user=ssh_user, + ssh_host=ssh_host, + ssh_port=ssh_port, + ssh_password=cli_args.ssh_password, + ssh_key_filename=ssh_key_filename, + init_command=combined_init_cmd, + unbuffered=cli_args.unbuffered, + character_set=cli_args.character_set, + use_keyring=use_keyring, + reset_keyring=reset_keyring, + keepalive_ticks=keepalive_ticks, + ) + if combined_init_cmd: + click.echo(f"Executing init-command: {combined_init_cmd}", err=True) -def is_select(status): - """Returns true if the first word in status is 'select'.""" - if not status: - return False - return status.split(None, 1)[0].lower() == 'select' + mycli.logger.debug( + "Launch Params: \n\tdatabase: %r\tuser: %r\thost: %r\tport: %r", + database, + cli_args.user, + cli_args.host, + cli_args.port, + ) + if cli_args.execute is not None: + sys.exit(main_execute_from_cli(mycli, cli_args)) -def thanks_picker(): - import mycli - lines = ( - resources.read_text(mycli, 'AUTHORS') + - resources.read_text(mycli, 'SPONSORS') - ).split('\n') + if cli_args.batch is not None and cli_args.batch != '-' and cli_args.progress and sys.stderr.isatty(): + sys.exit(main_batch_with_progress_bar(mycli, cli_args)) - contents = [] - for line in lines: - m = re.match(r'^ *\* (.*)', line) - if m: - contents.append(m.group(1)) - return choice(contents) + if cli_args.batch is not None: + sys.exit(main_batch_without_progress_bar(mycli, cli_args)) + if not sys.stdin.isatty(): + sys.exit(main_batch_from_stdin(mycli, cli_args)) -@prompt_register('edit-and-execute-command') -def edit_and_execute(event): - """Different from the prompt-toolkit default, we want to have a choice not - to execute a query after editing, hence validate_and_handle=False.""" - buff = event.current_buffer - buff.open_in_editor(validate_and_handle=False) + mycli.run_cli() + mycli.close() -def read_ssh_config(ssh_config_path): - ssh_config = paramiko.config.SSHConfig() +def main() -> int | None: try: - with open(ssh_config_path) as f: - ssh_config.parse(f) - except FileNotFoundError as e: - click.secho(str(e), err=True, fg='red') - sys.exit(1) - # Paramiko prior to version 2.7 raises Exception on parse errors. - # In 2.7 it has become paramiko.ssh_exception.SSHException, - # but let's catch everything for compatibility - except Exception as err: - click.secho( - f'Could not parse SSH configuration file {ssh_config_path}:\n{err} ', - err=True, fg='red' + result = click_entrypoint.main( + filtered_sys_argv(), + standalone_mode=False, # disable builtin exception handling + prog_name='mycli', ) + except click.Abort: + print('Aborted!', file=sys.stderr) sys.exit(1) + except BrokenPipeError: + sys.exit(1) + except click.ClickException as e: + e.show() + if hasattr(e, 'exit_code'): + sys.exit(e.exit_code) + else: + sys.exit(2) + if result is None: + return 0 + elif isinstance(result, int): + return result else: - return ssh_config + return 1 if __name__ == "__main__": - cli() + sys.exit(main()) diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py new file mode 100644 index 00000000..80c0f7d8 --- /dev/null +++ b/mycli/main_modes/batch.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +from io import TextIOWrapper +import os +import sys +import time +from typing import TYPE_CHECKING + +import click +import prompt_toolkit +from prompt_toolkit.shortcuts import ProgressBar +from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters +import pymysql + +from mycli.packages.batch_utils import statements_from_filehandle +from mycli.packages.interactive_utils import confirm_destructive_query +from mycli.packages.sql_utils import is_destructive + +if TYPE_CHECKING: + from mycli.main import CliArgs, MyCli + + +class CheckpointReplayError(Exception): + pass + + +def replay_checkpoint_file( + batch_path: str, + checkpoint: TextIOWrapper | None, + resume: bool, +) -> int: + if not resume: + return 0 + + if checkpoint is None: + return 0 + + if batch_path == '-': + raise CheckpointReplayError('--resume is incompatible with reading from the standard input.') + + checkpoint_name = checkpoint.name + checkpoint.flush() + completed_count = 0 + try: + with click.open_file(batch_path) as batch_h, click.open_file(checkpoint_name, mode='r', encoding='utf-8') as checkpoint_h: + try: + batch_gen = statements_from_filehandle(batch_h) + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + for checkpoint_statement, _checkpoint_counter in statements_from_filehandle(checkpoint_h): + try: + batch_statement, _batch_counter = next(batch_gen) + except StopIteration: + raise CheckpointReplayError('Checkpoint script longer than batch script.') from None + except ValueError as e: + raise CheckpointReplayError(f'Error reading --batch file: {batch_path}: {e}') from None + if checkpoint_statement != batch_statement: + raise CheckpointReplayError(f'Statement mismatch: {checkpoint_statement}.') + completed_count += 1 + except ValueError as e: + raise CheckpointReplayError(f'Error reading --checkpoint file: {checkpoint.name}: {e}') from None + except FileNotFoundError as e: + raise CheckpointReplayError(f'FileNotFoundError: {e}') from None + except OSError as e: + raise CheckpointReplayError(f'OSError: {e}') from None + + return completed_count + + +def dispatch_batch_statements( + mycli: 'MyCli', + cli_args: 'CliArgs', + statements: str, + batch_counter: int, +) -> None: + if batch_counter: + if cli_args.format == 'csv': + mycli.main_formatter.format_name = 'csv-noheader' + elif cli_args.format == 'tsv': + mycli.main_formatter.format_name = 'tsv_noheader' + elif cli_args.format == 'table': + mycli.main_formatter.format_name = 'ascii' + else: + mycli.main_formatter.format_name = 'tsv' + else: + if cli_args.format == 'csv': + mycli.main_formatter.format_name = 'csv' + elif cli_args.format == 'tsv': + mycli.main_formatter.format_name = 'tsv' + elif cli_args.format == 'table': + mycli.main_formatter.format_name = 'ascii' + else: + mycli.main_formatter.format_name = 'tsv' + + warn_confirmed: bool | None = True + if not cli_args.noninteractive and mycli.destructive_warning and is_destructive(mycli.destructive_keywords, statements): + try: + # this seems to work, even though we are reading from stdin above + sys.stdin = open('/dev/tty') + # bug: the prompt will not be visible if stdout is redirected + warn_confirmed = confirm_destructive_query(mycli.destructive_keywords, statements) + except (IOError, OSError) as e: + mycli.logger.warning('Unable to open TTY as stdin.') + raise e + if warn_confirmed: + if cli_args.throttle > 0 and batch_counter >= 1: + time.sleep(cli_args.throttle) + mycli.run_query(statements, checkpoint=cli_args.checkpoint, new_line=True) + + +def main_batch_with_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + goal_statements = 0 + if cli_args.batch is None: + return 1 + if not sys.stdin.isatty() and cli_args.batch != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='yellow') + if os.path.exists(cli_args.batch) and not os.path.isfile(cli_args.batch): + click.secho('--progress is only compatible with a plain file.', err=True, fg='red') + return 1 + try: + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) + batch_count_h = click.open_file(cli_args.batch) + for _statement, _counter in statements_from_filehandle(batch_count_h): + goal_statements += 1 + batch_count_h.close() + batch_h = click.open_file(cli_args.batch) + batch_gen = statements_from_filehandle(batch_h) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') + return 1 + except ValueError as e: + click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red') + return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 + try: + if goal_statements: + pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) + custom_formatters = [ + progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Progress(), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Text('eta ', style='class:time-left'), + progress_bar_formatters.TimeLeft(), + progress_bar_formatters.Text(' ', style='class:time-left'), + ] + err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True) + with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: + for _pb_counter in pb(range(goal_statements)): + statement, statement_counter = next(batch_gen) + if statement_counter < completed_statement_count: + continue + dispatch_batch_statements(mycli, cli_args, statement, statement_counter) + except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: + click.secho(str(e), err=True, fg='red') + return 1 + finally: + batch_h.close() + return 0 + + +def main_batch_without_progress_bar(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + if cli_args.batch is None: + return 1 + if not sys.stdin.isatty() and cli_args.batch != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') + try: + completed_statement_count = replay_checkpoint_file(cli_args.batch, cli_args.checkpoint, cli_args.resume) + batch_h = click.open_file(cli_args.batch) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') + return 1 + except CheckpointReplayError as e: + name = cli_args.checkpoint.name if cli_args.checkpoint else 'None' + click.secho(f'Error replaying --checkpoint file: {name}: {e}', err=True, fg='red') + return 1 + try: + for statement, counter in statements_from_filehandle(batch_h): + if counter < completed_statement_count: + continue + dispatch_batch_statements(mycli, cli_args, statement, counter) + except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: + click.secho(str(e), err=True, fg='red') + return 1 + finally: + batch_h.close() + return 0 + + +def main_batch_from_stdin(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + batch_h = click.get_text_stream('stdin') + try: + for statement, counter in statements_from_filehandle(batch_h): + dispatch_batch_statements(mycli, cli_args, statement, counter) + except (ValueError, StopIteration, IOError, OSError, pymysql.err.Error) as e: + click.secho(str(e), err=True, fg='red') + return 1 + return 0 diff --git a/mycli/main_modes/checkup.py b/mycli/main_modes/checkup.py new file mode 100644 index 00000000..c3b82a3b --- /dev/null +++ b/mycli/main_modes/checkup.py @@ -0,0 +1,156 @@ +import importlib.metadata +import json +import os +import shutil +import sys +import urllib.error +import urllib.request + +from mycli.constants import REPO_URL + +PYPI_API_BASE = 'https://pypi.org/pypi' + + +def pypi_api_fetch(fragment: str) -> dict: + fragment = fragment.lstrip('/') + url = f'{PYPI_API_BASE}/{fragment}' + try: + with urllib.request.urlopen(url, timeout=5) as response: + return json.loads(response.read().decode('utf8')) + except urllib.error.URLError: + print(f'Failed to connect to PyPi on {url}', file=sys.stderr) + return {} + + +def _dependencies_checkup() -> None: + print('\n### Key Python dependencies:\n') + for dependency in [ + 'cli_helpers', + 'click', + 'prompt_toolkit', + 'pymysql', + 'tabulate', + ]: + try: + installed_version = importlib.metadata.version(dependency) + except importlib.metadata.PackageNotFoundError: + installed_version = None + pypi_profile = pypi_api_fetch(f'/{dependency}/json') + latest_version = pypi_profile.get('info', {}).get('version', None) + print(f'{dependency} version {installed_version} (latest {latest_version})') + + +def _executables_checkup() -> None: + print('\n### External executables:\n') + for executable in [ + 'less', + 'fzf', + 'pygmentize', + ]: + if shutil.which(executable): + print(f'The "{executable}" executable was found — good!') + else: + print(f'The recommended "{executable}" executable was not found — some functionality will suffer.') + + +def _environment_checkup() -> None: + print('\n### Environment variables:\n') + for variable in [ + 'EDITOR', + 'VISUAL', + ]: + if value := os.environ.get(variable): + print(f'The ${variable} environment variable was set to "{value}" — good!') + else: + print(f'The ${variable} environment variable was not set — some functionality will suffer.') + + +def _configuration_checkup(mycli) -> None: + did_output_missing = False + did_output_unsupported = False + did_output_deprecated = False + + indent = ' ' + transitions = { + f'{indent}[main]\n{indent}default_character_set': f'{indent}[connection]\n{indent}default_character_set', + f'{indent}[main]\n{indent}ssl_mode': f'{indent}[connection]\n{indent}default_ssl_mode', + } + reverse_transitions = {v: k for k, v in transitions.items()} + + if not list(mycli.config.keys()): + print('\n### Missing file:\n') + print('The local ~/,myclirc is missing or empty.\n') + did_output_missing = True + else: + for section_name in mycli.config: + if section_name not in mycli.config_without_package_defaults: + if not did_output_missing: + print('\n### Missing in user ~/.myclirc:\n') + print(f'The entire section:\n\n{indent}[{section_name}]\n') + did_output_missing = True + continue + for item_name in mycli.config[section_name]: + transition_key = f'{indent}[{section_name}]\n{indent}{item_name}' + if transition_key in reverse_transitions: + continue + if item_name not in mycli.config_without_package_defaults[section_name]: + if not did_output_missing: + print('\n### Missing in user ~/.myclirc:\n') + print(f'The item:\n\n{indent}[{section_name}]\n{indent}{item_name} =\n') + did_output_missing = True + + for section_name in mycli.config_without_package_defaults: + if section_name not in mycli.config_without_user_options: + if not did_output_unsupported: + print('\n### Unsupported in user ~/.myclirc:\n') + did_output_unsupported = True + print(f'The entire section:\n\n{indent}[{section_name}]\n') + continue + for item_name in mycli.config_without_package_defaults[section_name]: + if section_name == 'colors' and item_name.startswith('sql.'): + # these are commented out in the package myclirc + continue + if section_name in [ + 'favorite_queries', + 'init-commands', + 'alias_dsn', + 'alias_dsn.init-commands', + ]: + # these are free-entry sections, so a comparison per item is not meaningful + continue + transition_key = f'{indent}[{section_name}]\n{indent}{item_name}' + if transition_key in transitions: + continue + if item_name not in mycli.config_without_user_options[section_name]: + if not did_output_unsupported: + print('\n### Unsupported in user ~/.myclirc:\n') + print(f'The item:\n\n{indent}[{section_name}]\n{indent}{item_name} =\n') + did_output_unsupported = True + + for section_name in mycli.config_without_package_defaults: + if section_name not in mycli.config_without_user_options: + continue + for item_name in mycli.config_without_package_defaults[section_name]: + if section_name == 'colors' and item_name.startswith('sql.'): + # these are commented out in the package myclirc + continue + transition_key = f'{indent}[{section_name}]\n{indent}{item_name}' + if transition_key in transitions: + if not did_output_deprecated: + print('\n### Deprecated in user ~/.myclirc:\n') + transition_value = transitions[transition_key] + print(f'It is recommended to transition:\n\n{transition_key}\n\nto\n\n{transition_value}\n') + did_output_deprecated = True + + if did_output_missing or did_output_unsupported or did_output_deprecated: + print(f'For more info on supported features, see the commentary and defaults at:\n\n * {REPO_URL}/blob/main/mycli/myclirc\n') + else: + print('\n### Configuration:\n') + print('User configuration all up to date!\n') + + +def main_checkup(mycli) -> None: + _dependencies_checkup() + _executables_checkup() + _environment_checkup() + _configuration_checkup(mycli) diff --git a/mycli/main_modes/execute.py b/mycli/main_modes/execute.py new file mode 100644 index 00000000..abe25562 --- /dev/null +++ b/mycli/main_modes/execute.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +import click + +if TYPE_CHECKING: + from mycli.main import CliArgs, MyCli + + +def main_execute_from_cli(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + if cli_args.execute is None: + return 1 + if not sys.stdin.isatty(): + click.secho('Ignoring STDIN since --execute was also given.', err=True, fg='red') + if cli_args.batch: + click.secho('Ignoring --batch since --execute was also given.', err=True, fg='red') + try: + execute_sql = cli_args.execute + if cli_args.format == 'csv': + mycli.main_formatter.format_name = 'csv' + if execute_sql.endswith(r'\G'): + execute_sql = execute_sql[:-2] + elif cli_args.format == 'tsv': + mycli.main_formatter.format_name = 'tsv' + if execute_sql.endswith(r'\G'): + execute_sql = execute_sql[:-2] + elif cli_args.format == 'table': + mycli.main_formatter.format_name = 'ascii' + if execute_sql.endswith(r'\G'): + execute_sql = execute_sql[:-2] + else: + mycli.main_formatter.format_name = 'tsv' + + mycli.run_query(execute_sql, checkpoint=cli_args.checkpoint) + return 0 + except Exception as e: + click.secho(str(e), err=True, fg="red") + return 1 diff --git a/mycli/main_modes/list_dsn.py b/mycli/main_modes/list_dsn.py new file mode 100644 index 00000000..6a00a2c6 --- /dev/null +++ b/mycli/main_modes/list_dsn.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import click + +if TYPE_CHECKING: + from mycli.main import MyCli + + +def main_list_dsn(mycli: 'MyCli') -> int: + try: + alias_dsn = mycli.config['alias_dsn'] + except KeyError: + click.secho('Invalid DSNs found in the config file. Please check the "[alias_dsn]" section in myclirc.', err=True, fg='red') + return 1 + except Exception as e: + click.secho(str(e), err=True, fg='red') + return 1 + for alias, value in alias_dsn.items(): + if mycli.verbosity >= 1: + click.secho(f'{alias} : {value}') + else: + click.secho(alias) + return 0 diff --git a/mycli/main_modes/list_ssh_config.py b/mycli/main_modes/list_ssh_config.py new file mode 100644 index 00000000..4d3b8cfc --- /dev/null +++ b/mycli/main_modes/list_ssh_config.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import click + +from mycli.packages.ssh_utils import read_ssh_config + +if TYPE_CHECKING: + from mycli.main import CliArgs, MyCli + + +def main_list_ssh_config(mycli: 'MyCli', cli_args: 'CliArgs') -> int: + ssh_config = read_ssh_config(cli_args.ssh_config_path) + try: + host_entries = ssh_config.get_hostnames() + except KeyError: + click.secho('Error reading ssh config', err=True, fg="red") + return 1 + for host_entry in host_entries: + if mycli.verbosity >= 1: + host_config = ssh_config.lookup(host_entry) + click.secho(f"{host_entry} : {host_config.get('hostname')}") + else: + click.secho(host_entry) + return 0 diff --git a/mycli/main_modes/repl.py b/mycli/main_modes/repl.py new file mode 100644 index 00000000..43e05e5d --- /dev/null +++ b/mycli/main_modes/repl.py @@ -0,0 +1,876 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +import functools +from functools import partial +import html +from importlib import resources +import os +import random +import re +import subprocess +import sys +import time +import traceback +from typing import TYPE_CHECKING, Any, Generator +from xml.parsers.expat import ExpatError + +import click +import prompt_toolkit +from prompt_toolkit.application.current import get_app +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest +from prompt_toolkit.completion import DynamicCompleter +from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode +from prompt_toolkit.filters import Condition, has_focus, is_done +from prompt_toolkit.formatted_text import ( + ANSI, + HTML, + FormattedText, + to_formatted_text, + to_plain_text, +) +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor +from prompt_toolkit.lexers import PygmentsLexer +from prompt_toolkit.output import ColorDepth +from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +import pymysql +from pymysql.cursors import Cursor + +import mycli as mycli_package +from mycli.clibuffer import cli_is_multiline +from mycli.clistyle import style_factory_ptoolkit +from mycli.clitoolbar import create_toolbar_tokens_func +from mycli.constants import ( + DEFAULT_HOST, + DEFAULT_WIDTH, + ER_MUST_CHANGE_PASSWORD, + HOME_URL, + ISSUES_URL, +) +from mycli.key_bindings import mycli_bindings +from mycli.lexer import MyCliLexer +from mycli.packages import special +from mycli.packages.filepaths import dir_path_exists +from mycli.packages.hybrid_redirection import get_redirect_components, is_redirect_command +from mycli.packages.interactive_utils import confirm, confirm_destructive_query +from mycli.packages.key_binding_utils import ( + handle_clip_command, + handle_editor_command, +) +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.special.utils import format_uptime, get_ssl_version, get_uptime, get_warning_count +from mycli.packages.sql_utils import ( + extract_new_password, + is_dropping_database, + is_mutating, + is_password_change, + is_sandbox_allowed, + is_select, + need_completion_refresh, + need_completion_reset, +) +from mycli.packages.sqlresult import SQLResult +from mycli.packages.string_utils import sanitize_terminal_title +from mycli.sqlexecute import SQLExecute +from mycli.types import Query + +if TYPE_CHECKING: + from prompt_toolkit.formatted_text import AnyFormattedText + + from mycli.main import MyCli + + +SUPPORT_INFO = f"Home: {HOME_URL}\nBug tracker: {ISSUES_URL}" +MIN_COMPLETION_TRIGGER = 1 +_PROMPT_TARGETS: dict[int, 'MyCli'] = {} + + +@dataclass(slots=True) +class ReplState: + iterations: int = 0 + mutating: bool = False + + +@Condition +def complete_while_typing_filter() -> bool: + """Whether enough characters have been typed to trigger completion. + + Written in a verbose way, with a string slice, for efficiency.""" + if MIN_COMPLETION_TRIGGER <= 1: + return True + app = get_app() + text = app.current_buffer.text.lstrip() + text_len = len(text) + if text_len < MIN_COMPLETION_TRIGGER: + return False + last_word = text[-MIN_COMPLETION_TRIGGER:] + if len(last_word) == text_len: + return text_len >= MIN_COMPLETION_TRIGGER + if text[:6].lower() in ['source', r'\.']: + # Different word characters for paths; see comment below. + # In fact, it might be nice if paths had a different threshold. + return not bool(re.search(r'[\s!-,:-@\[-^\{\}-]', last_word)) + else: + # This is "whitespace and all punctuation except underscore and backtick" + # acting as word breaks, but it would be neat if we could complete differently + # when inside a backtick, accepting all legal characters towards the trigger + # limit. We would have to parse the statement, or at least go back more + # characters, costing performance. This still works within a backtick! So + # long as there are three trailing non-punctuation characters. + return not bool(re.search(r'[\s!-/:-@\[-^\{-~]', last_word)) + + +def _create_history(mycli: 'MyCli') -> FileHistoryWithTimestamp | None: + history_file = os.path.expanduser(os.environ.get('MYCLI_HISTFILE', mycli.config['main'].get('history_file', '~/.mycli-history'))) + if dir_path_exists(history_file): + return FileHistoryWithTimestamp(history_file) + + mycli.echo( + f'Error: Unable to open the history file "{history_file}". Your query history will not be saved.', + err=True, + fg='red', + ) + return None + + +def _show_startup_banner( + mycli: 'MyCli', + sqlexecute: SQLExecute, +) -> None: + if mycli.verbosity < 0: + return + + if sqlexecute.server_info is not None: + print(sqlexecute.server_info) + print('mycli', mycli_package.__version__) + print(SUPPORT_INFO) + if random.random() <= 0.25: + print('Thanks to the sponsor —', _sponsors_picker()) + elif random.random() <= 0.5: + print('Thanks to the contributor —', _contributors_picker()) + else: + print('Tip —', _tips_picker()) + + +def set_all_external_titles(mycli: 'MyCli') -> None: + set_external_terminal_tab_title(mycli) + set_external_terminal_window_title(mycli) + set_external_multiplex_window_title(mycli) + set_external_multiplex_pane_title(mycli) + + +def set_external_terminal_tab_title(mycli: 'MyCli') -> None: + if not mycli.terminal_tab_title_format: + return + if not mycli.prompt_session: + return + if not sys.stderr.isatty(): + return + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.terminal_tab_title_format, + mycli.prompt_session.app.render_counter, + ) + ) + print(f'\x1b]1;{title}\a', file=sys.stderr, end='') + sys.stderr.flush() + + +def set_external_terminal_window_title(mycli: 'MyCli') -> None: + if not mycli.terminal_window_title_format: + return + if not mycli.prompt_session: + return + if not sys.stderr.isatty(): + return + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.terminal_window_title_format, + mycli.prompt_session.app.render_counter, + ) + ) + print(f'\x1b]2;{title}\a', file=sys.stderr, end='') + sys.stderr.flush() + + +def set_external_multiplex_window_title(mycli: 'MyCli') -> None: + if not mycli.multiplex_window_title_format: + return + if not os.getenv('TMUX'): + return + if not mycli.prompt_session: + return + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.multiplex_window_title_format, + mycli.prompt_session.app.render_counter, + ) + ) + try: + subprocess.run( + ['tmux', 'rename-window', title], + check=False, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except FileNotFoundError: + pass + + +def set_external_multiplex_pane_title(mycli: 'MyCli') -> None: + if not mycli.multiplex_pane_title_format: + return + if not os.getenv('TMUX'): + return + if not mycli.prompt_session: + return + if not sys.stderr.isatty(): + return + title = sanitize_terminal_title( + render_prompt_string( + mycli, + mycli.multiplex_pane_title_format, + mycli.prompt_session.app.render_counter, + ) + ) + print(f'\x1b]2;{title}\x1b\\', file=sys.stderr, end='') + sys.stderr.flush() + + +def get_custom_toolbar( + mycli: 'MyCli', + toolbar_format: str, +) -> FormattedText: + if not mycli.prompt_session: + return to_formatted_text('') + if not mycli.prompt_session.app: + return to_formatted_text('') + if mycli.prompt_session.app.current_buffer.text: + return mycli.last_custom_toolbar_message + mycli.last_custom_toolbar_message = render_prompt_string( + mycli, + toolbar_format, + mycli.prompt_session.app.render_counter, + ) + return mycli.last_custom_toolbar_message + + +def maybe_html_escape(string: str, is_html: bool) -> str: + if is_html: + return html.escape(string, quote=False) + return string + + +@functools.lru_cache(maxsize=256) +def render_prompt_string( + mycli: 'MyCli', + string: str, + _render_counter: int, +) -> FormattedText: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + if mycli.login_path and mycli.login_path_as_host: + prompt_host = mycli.login_path + elif sqlexecute.host is not None: + prompt_host = sqlexecute.host + else: + prompt_host = DEFAULT_HOST + short_prompt_host, _, _ = prompt_host.partition('.') + if re.match(r'^[\d\.]+$', short_prompt_host): + short_prompt_host = prompt_host + now = datetime.now() + species_name = sqlexecute.server_info.species.name if sqlexecute.server_info and sqlexecute.server_info.species else 'MySQL' + strings = string.split('\\\\') + is_html = strings[0].startswith('\\') + strings = [x.replace('\\u', maybe_html_escape(sqlexecute.user or '(none)', is_html)) for x in strings] + strings = [x.replace('\\h', maybe_html_escape(prompt_host or '(none)', is_html)) for x in strings] + strings = [x.replace('\\H', maybe_html_escape(short_prompt_host or '(none)', is_html)) for x in strings] + strings = [x.replace('\\d', maybe_html_escape(sqlexecute.dbname or '(none)', is_html)) for x in strings] + strings = [x.replace('\\t', maybe_html_escape(species_name, is_html)) for x in strings] + strings = [x.replace('\\n', '\n') for x in strings] + strings = [x.replace('\\D', maybe_html_escape(now.strftime('%a %b %d %H:%M:%S %Y'), is_html)) for x in strings] + strings = [x.replace('\\m', maybe_html_escape(now.strftime('%M'), is_html)) for x in strings] + strings = [x.replace('\\P', maybe_html_escape(now.strftime('%p'), is_html)) for x in strings] + strings = [x.replace('\\R', maybe_html_escape(now.strftime('%H'), is_html)) for x in strings] + strings = [x.replace('\\r', maybe_html_escape(now.strftime('%I'), is_html)) for x in strings] + strings = [x.replace('\\s', maybe_html_escape(now.strftime('%S'), is_html)) for x in strings] + strings = [x.replace('\\p', maybe_html_escape(str(sqlexecute.port), is_html)) for x in strings] + strings = [ + x.replace('\\j', maybe_html_escape(os.path.basename(sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings + ] + strings = [x.replace('\\J', maybe_html_escape((sqlexecute.socket or '(none)').replace('\\', '/'), is_html)) for x in strings] + strings = [ + x.replace('\\k', maybe_html_escape(os.path.basename(sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html)) + for x in strings + ] + strings = [ + x.replace('\\K', maybe_html_escape((sqlexecute.socket or str(sqlexecute.port)).replace('\\', '/'), is_html)) for x in strings + ] + strings = [x.replace('\\A', maybe_html_escape(mycli.dsn_alias or '(none)', is_html)) for x in strings] + strings = [x.replace('\\_', ' ') for x in strings] + + checker_string = ' '.join(strings) + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: + if '\\y' in checker_string: + with sqlexecute.conn.cursor() as cur: + strings = [x.replace('\\y', maybe_html_escape(str(get_uptime(cur)) or '(none)', is_html)) for x in strings] + if '\\Y' in checker_string: + with sqlexecute.conn.cursor() as cur: + strings = [x.replace('\\Y', maybe_html_escape(format_uptime(str(get_uptime(cur))) or '(none)', is_html)) for x in strings] + else: + strings = [x.replace('\\y', '(none)') for x in strings] + strings = [x.replace('\\Y', '(none)') for x in strings] + + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: + if '\\T' in checker_string: + with sqlexecute.conn.cursor() as cur: + strings = [x.replace('\\T', maybe_html_escape(get_ssl_version(cur) or '(none)', is_html)) for x in strings] + else: + strings = [x.replace('\\T', '(none)') for x in strings] + + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: + if '\\w' in checker_string: + with sqlexecute.conn.cursor() as cur: + strings = [x.replace('\\w', maybe_html_escape(str(get_warning_count(cur) or '(none)'), is_html)) for x in strings] + else: + strings = [x.replace('\\w', '(none)') for x in strings] + if hasattr(sqlexecute, 'conn') and sqlexecute.conn is not None: + if '\\W' in checker_string: + with sqlexecute.conn.cursor() as cur: + strings = [x.replace('\\W', maybe_html_escape(str(get_warning_count(cur) or ''), is_html)) for x in strings] + else: + strings = [x.replace('\\W', '') for x in strings] + + if is_html: + strings[0] = strings[0].removeprefix('\\') + strings[-1] = strings[-1].removesuffix('\\') + elif '\\x1b' in checker_string: + strings = [x.replace('\\x1b', '\x1b') for x in strings] + + strings = [re.sub(r'\\(.)', r'(unknown prompt format string: \\\1)', x) for x in strings] + + string = '\\'.join(strings) + + if is_html: + try: + formatted_string = to_formatted_text(HTML(string)) + except (ExpatError, ValueError): + formatted_string = to_formatted_text(HTML('(cannot parse HTML prompt string)')) + else: + formatted_string = to_formatted_text(ANSI(string)) + + return formatted_string + + +def _get_prompt_message( + mycli: 'MyCli', + app: prompt_toolkit.application.application.Application, +) -> FormattedText: + if app.current_buffer.text: + return mycli.last_prompt_message + + prompt = render_prompt_string(mycli, mycli.prompt_format, app.render_counter) + prompt_plain = to_plain_text(prompt) + if mycli.prompt_format == mycli.default_prompt and len(prompt_plain) > mycli.max_len_prompt: + prompt = render_prompt_string(mycli, mycli.default_prompt_splitln, app.render_counter) + prompt_plain = to_plain_text(prompt) + mycli.prompt_lines = prompt_plain.count('\n') + 1 + if not mycli.prompt_lines: + mycli.prompt_lines = prompt_plain.count('\n') + 1 + + mycli.last_prompt_message = prompt + return mycli.last_prompt_message + + +def _get_continuation( + mycli: 'MyCli', + width: int, + _two: int, + _three: int, +) -> AnyFormattedText: + if mycli.multiline_continuation_char == '': + continuation = '' + elif mycli.multiline_continuation_char: + left_padding = width - len(mycli.multiline_continuation_char) + continuation = ' ' * max((left_padding - 1), 0) + mycli.multiline_continuation_char + ' ' + else: + continuation = ' ' + return [('class:continuation', continuation)] + + +def _output_results( + mycli: 'MyCli', + state: ReplState, + results: Generator[SQLResult], + start: float, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + result_count = 0 + watch_count = 0 + for result in results: + mycli.logger.debug('preamble: %r', result.preamble) + mycli.logger.debug('header: %r', result.header) + mycli.logger.debug('rows: %r', result.rows) + mycli.logger.debug('status: %r', result.status) + mycli.logger.debug('command: %r', result.command) + threshold = 1000 + if result.command is not None and result.command['name'] == 'watch': + if watch_count > 0: + try: + watch_seconds = float(result.command['seconds']) + start += watch_seconds + except ValueError as e: + mycli.echo(f'Invalid watch sleep time provided ({e}).', err=True, fg='red') + sys.exit(1) + else: + watch_count += 1 + + if is_select(result.status_plain) and isinstance(result.rows, Cursor) and result.rows.rowcount > threshold: + mycli.echo( + f'The result set has more than {threshold} rows.', + fg='red', + ) + if not confirm('Do you want to continue?'): + mycli.echo('Aborted!', err=True, fg='red') + break + + if mycli.auto_vertical_output: + if mycli.prompt_session is not None: + max_width = mycli.prompt_session.output.get_size().columns + else: + max_width = DEFAULT_WIDTH + else: + max_width = None + + formatted = mycli.format_sqlresult( + result, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + ) + + duration = time.time() - start + try: + if result_count > 0: + mycli.echo('') + try: + mycli.output(formatted, result) + except KeyboardInterrupt: + pass + if mycli.beep_after_seconds > 0 and duration >= mycli.beep_after_seconds: + assert mycli.prompt_session is not None + mycli.prompt_session.output.bell() + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:0.03f}s') + except KeyboardInterrupt: + pass + + start = time.time() + result_count += 1 + state.mutating = state.mutating or is_mutating(result.status_plain) + + if special.is_show_warnings_enabled() and isinstance(result.rows, Cursor) and result.rows.warning_count > 0: + warnings = sqlexecute.run('SHOW WARNINGS') + warnings_duration = time.time() - start + saw_warning = False + for warning in warnings: + saw_warning = True + formatted = mycli.format_sqlresult( + warning, + is_expanded=special.is_expanded_output(), + is_redirected=special.is_redirected(), + null_string=mycli.null_string, + numeric_alignment=mycli.numeric_alignment, + binary_display=mycli.binary_display, + max_width=max_width, + is_warnings_style=True, + ) + mycli.echo('') + mycli.output(formatted, warning, is_warnings_style=True) + + if saw_warning and special.is_timing_enabled(): + mycli.output_timing(f'Time: {warnings_duration:0.03f}s', is_warnings_style=True) + + +def _keepalive_hook( + mycli: 'MyCli', + _context: Any, +) -> None: + if mycli.keepalive_ticks is None: + return + if mycli.keepalive_ticks < 1: + return + + mycli._keepalive_counter += 1 + if mycli._keepalive_counter > mycli.keepalive_ticks: + mycli._keepalive_counter = 0 + mycli.logger.debug('keepalive ping') + try: + assert mycli.sqlexecute is not None + assert mycli.sqlexecute.conn is not None + mycli.sqlexecute.conn.ping(reconnect=False) + except Exception as e: + mycli.logger.debug('keepalive ping error %r', e) + + +def _build_prompt_session( + mycli: 'MyCli', + state: ReplState, + history: FileHistoryWithTimestamp | None, + key_bindings: KeyBindings, +) -> None: + if mycli.toolbar_format.lower() == 'none': + get_toolbar_tokens = None + else: + get_toolbar_tokens = create_toolbar_tokens_func( + mycli, + lambda: state.iterations == 0, + mycli.toolbar_format, + partial(get_custom_toolbar, mycli), + ) + + if mycli.wider_completion_menu: + complete_style = CompleteStyle.MULTI_COLUMN + else: + complete_style = CompleteStyle.COLUMN + + with mycli._completer_lock: + if mycli.key_bindings == 'vi': + editing_mode = EditingMode.VI + else: + editing_mode = EditingMode.EMACS + + mycli.prompt_session = PromptSession( + color_depth=ColorDepth.DEPTH_24_BIT if 'truecolor' in os.getenv('COLORTERM', '').lower() else None, + lexer=PygmentsLexer(MyCliLexer), + reserve_space_for_menu=mycli.get_reserved_space(), + prompt_continuation=lambda width, two, three: _get_continuation(mycli, width, two, three), + bottom_toolbar=get_toolbar_tokens, + complete_style=complete_style, + input_processors=[ + ConditionalProcessor( + processor=HighlightMatchingBracketProcessor(chars='[](){}'), + filter=has_focus(DEFAULT_BUFFER) & ~is_done, + ) + ], + tempfile_suffix='.sql', + completer=DynamicCompleter(lambda: mycli.completer), + complete_in_thread=True, + history=history, + auto_suggest=ThreadedAutoSuggest(AutoSuggestFromHistory()), + complete_while_typing=complete_while_typing_filter, + multiline=cli_is_multiline(mycli), + style=style_factory_ptoolkit(mycli.syntax_style, mycli.cli_style), + include_default_pygments_style=False, + key_bindings=key_bindings, + enable_open_in_editor=True, + enable_system_prompt=True, + enable_suspend=True, + editing_mode=editing_mode, + search_ignore_case=True, + ) + + if mycli.key_bindings == 'vi': + mycli.prompt_session.app.ttimeoutlen = mycli.vi_ttimeoutlen + else: + mycli.prompt_session.app.ttimeoutlen = mycli.emacs_ttimeoutlen + + +def _one_iteration( + mycli: 'MyCli', + state: ReplState, + text: str | None = None, +) -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + + inputhook = partial(_keepalive_hook, mycli) if mycli.keepalive_ticks and mycli.keepalive_ticks >= 1 else None + + if text is None: + try: + assert mycli.prompt_session is not None + loaded_message_fn = partial(_get_prompt_message, mycli, mycli.prompt_session.app) + text = mycli.prompt_session.prompt( + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + + special.set_expanded_output(False) + special.set_forced_horizontal_output(False) + + try: + text = handle_editor_command( + mycli, + text, + inputhook, + loaded_message_fn, + ) + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + try: + if handle_clip_command(mycli, text): + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + while special.is_llm_command(text): + start = time.time() + try: + assert sqlexecute.conn is not None + cur = sqlexecute.conn.cursor() + context, sql, duration = special.handle_llm( + text, + cur, + sqlexecute.dbname or '', + mycli.llm_prompt_field_truncate, + mycli.llm_prompt_section_truncate, + ) + if context: + click.echo('LLM Response:') + click.echo(context) + click.echo('---') + if special.is_timing_enabled(): + mycli.output_timing(f'Time: {duration:0.03f}s') + assert mycli.prompt_session is not None + text = mycli.prompt_session.prompt( + default=sql or '', + inputhook=inputhook, + message=loaded_message_fn, + ) + except KeyboardInterrupt: + return + except special.FinishIteration as e: + if e.results: + _output_results(mycli, state, e.results, start) + return + except RuntimeError as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + text = text.strip() + if not text: + return + + if is_redirect_command(text): + sql_part, command_part, file_operator_part, file_part = get_redirect_components(text) + text = sql_part or '' + try: + special.set_redirect(command_part, file_operator_part, file_part) + except (FileNotFoundError, OSError, RuntimeError) as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + return + + if mycli.sandbox_mode and not is_sandbox_allowed(text): + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + return + + if mycli.destructive_warning: + destroy = confirm_destructive_query(mycli.destructive_keywords, text) + if destroy is None: + pass + elif destroy is True: + mycli.echo('Your call!') + else: + mycli.echo('Wise choice!') + return + + successful = False + try: + mycli.logger.debug('sql: %r', text) + special.write_tee(mycli.last_prompt_message, nl=False) + special.write_tee(text) + mycli.log_query(text) + + start = time.time() + results = sqlexecute.run(text) + mycli.main_formatter.query = text + mycli.redirect_formatter.query = text + successful = True + _output_results(mycli, state, results, start) + special.unset_once_if_written(mycli.post_redirect_command) + special.flush_pipe_once_if_written(mycli.post_redirect_command) + except pymysql.err.InterfaceError: + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + except EOFError as e: + raise e + except KeyboardInterrupt: + connection_id_to_kill = sqlexecute.connection_id or 0 + if connection_id_to_kill > 0: + mycli.logger.debug('connection id to kill: %r', connection_id_to_kill) + try: + sqlexecute.connect() + for kill_result in sqlexecute.run(f'kill {connection_id_to_kill}'): + status_str = str(kill_result.status_plain).lower() + if status_str.find('ok') > -1: + mycli.logger.debug('cancelled query, connection id: %r, sql: %r', connection_id_to_kill, text) + mycli.echo(f'Cancelled query id: {connection_id_to_kill}', err=True, fg='blue') + else: + mycli.logger.debug( + 'Failed to confirm query cancellation, connection id: %r, sql: %r', + connection_id_to_kill, + text, + ) + mycli.echo(f'Failed to confirm query cancellation, id: {connection_id_to_kill}', err=True, fg='red') + except Exception as e2: + mycli.echo(f'Encountered error while cancelling query: {e2}', err=True, fg='red') + else: + mycli.logger.debug('Did not get a connection id, skip cancelling query') + mycli.echo('Did not get a connection id, skip cancelling query', err=True, fg='red') + except NotImplementedError: + mycli.echo('Not Yet Implemented.', fg='yellow') + except pymysql.OperationalError as e1: + mycli.logger.debug('Exception: %r', e1) + if e1.args[0] == ER_MUST_CHANGE_PASSWORD: + mycli.sandbox_mode = True + mycli.echo( + "ERROR 1820: You must reset your password using ALTER USER or SET PASSWORD before executing this statement.", + err=True, + fg='red', + ) + elif e1.args[0] in (2003, 2006, 2013): + if not mycli.reconnect(): + return + _one_iteration(mycli, state, text) + return + else: + mycli.logger.error('sql: %r, error: %r', text, e1) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e1), err=True, fg='red') + except Exception as e: + mycli.logger.error('sql: %r, error: %r', text, e) + mycli.logger.error('traceback: %r', traceback.format_exc()) + mycli.echo(str(e), err=True, fg='red') + else: + if mycli.sandbox_mode and is_password_change(text): + new_password = extract_new_password(text) + if new_password is not None: + sqlexecute.password = new_password + try: + sqlexecute.connect() + mycli.sandbox_mode = False + mycli.echo("Password changed successfully. Reconnected.", err=True, fg='green') + mycli.refresh_completions() + except Exception as e: + mycli.sandbox_mode = False + mycli.echo( + f"Password changed but reconnection failed: {e}\nPlease restart mycli with your new password.", + err=True, + fg='yellow', + ) + + if is_dropping_database(text, sqlexecute.dbname): + sqlexecute.dbname = None + sqlexecute.connect() + + if need_completion_refresh(text): + mycli.refresh_completions(reset=need_completion_reset(text)) + finally: + if mycli.logfile is False: + mycli.echo('Warning: This query was not logged.', err=True, fg='red') + + query = Query(text, successful, state.mutating) + mycli.query_history.append(query) + + +def _contributors_picker() -> str: + lines: str = "" + + try: + with resources.files(mycli_package).joinpath("AUTHORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + contents = [] + for line in lines.split("\n"): + if m := re.match(r"^ *\* (.*)", line): + contents.append(m.group(1)) + return random.choice(contents) if contents else 'our contributors' + + +def _sponsors_picker() -> str: + lines: str = "" + + try: + with resources.files(mycli_package).joinpath("SPONSORS").open('r') as f: + lines += f.read() + except FileNotFoundError: + pass + + contents = [] + for line in lines.split("\n"): + if m := re.match(r"^ *\* (.*)", line): + contents.append(m.group(1)) + return random.choice(contents) if contents else 'our sponsors' + + +def _tips_picker() -> str: + tips = [] + + try: + with resources.files(mycli_package).joinpath('TIPS').open('r') as f: + for line in f: + if line.startswith("#"): + continue + if tip := line.strip(): + tips.append(tip) + except FileNotFoundError: + pass + + return random.choice(tips) if tips else r'\? or "help" for help!' + + +def main_repl(mycli: 'MyCli') -> None: + sqlexecute = mycli.sqlexecute + assert sqlexecute is not None + state = ReplState() + + mycli.configure_pager() + if mycli.smart_completion and not mycli.sandbox_mode: + mycli.refresh_completions() + + history = _create_history(mycli) + key_bindings = mycli_bindings(mycli) + _show_startup_banner(mycli, sqlexecute) + _build_prompt_session(mycli, state, history, key_bindings) + set_all_external_titles(mycli) + + try: + while True: + _one_iteration(mycli, state) + state.iterations += 1 + except EOFError: + special.close_tee() + if mycli.verbosity >= 0: + mycli.echo('Goodbye!') diff --git a/mycli/myclirc b/mycli/myclirc index cd58dfe2..76663572 100644 --- a/mycli/myclirc +++ b/mycli/myclirc @@ -1,10 +1,29 @@ # vi: ft=dosini [main] +# Enable or disable the automatic displaying of warnings ("SHOW WARNINGS") +# after executing a SQL statement when applicable. +show_warnings = False + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True +# Minimum characters typed before offering completion suggestions. +# Suggestion: 3. +min_completion_trigger = 1 + +# Prefetch completion metadata for schemas in the background after launch. +# Possible values: +# always = prefetch all schemas (default) +# never = do not prefetch any schemas +# listed = prefetch only the schemas named in prefetch_schemas_list +prefetch_schemas_mode = always + +# Comma-separated list of schemas to prefetch when +# prefetch_schemas_mode = listed. Ignored in other modes. +prefetch_schemas_list = + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple @@ -16,6 +35,14 @@ multi_line = False # or "shutdown". destructive_warning = True +# Queries starting with these keywords will activate the destructive warning. +# UPDATE will not activate the warning if the statement includes a WHERE +# clause. +destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE + +# interactive query history location. +history_file = ~/.mycli-history + # log_file location. log_file = ~/.mycli.log @@ -27,23 +54,51 @@ log_level = INFO # line below. # audit_log = ~/.mycli-audit.log -# Timing of sql statements and table rendering. +# Timing of SQL statements and table rendering, or LLM commands. timing = True +# Show the full SQL when running a favorite query. Set to False to hide. +show_favorite_query = True + # Beep after long-running queries are completed; 0 to disable. beep_after_seconds = 0 -# Table format. Possible values: ascii, double, github, -# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html, -# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv. -# Recommended: ascii +# Table format. Possible values: ascii, ascii_escaped, csv, csv-noheader, +# csv-tab, csv-tab-noheader, double, fancy_grid, github, grid, html, jira, +# jsonl, jsonl_escaped, latex, latex_booktabs, mediawiki, minimal, moinmoin, +# mysql, mysql_unicode, mysql_heavy, orgtbl, pipe, plain, psql, psql_unicode, +# rst, simple, sql-insert, sql-update, sql-update-1, sql-update-2, textile, +# tsv, tsv_noheader, vertical. +# Recommended: mysql_unicode. table_format = ascii +# Redirected otuput format +# Recommended: csv. +redirect_format = csv + +# How to display the missing value (ie NULL). Only certain table formats +# support configuring the missing value. CSV for example always uses the +# empty string, and JSON formats use native nulls. +null_string = + +# How to align numeric data in tabular output: right or left. +numeric_alignment = right + +# How to display binary values in tabular output: "hex", or "utf8". "utf8" +# means attempt to render valid UTF-8 sequences as strings, then fall back +# to hex rendering if not possible. +binary_display = hex + +# A command to run after a successful output redirect, with {} to be replaced +# with the escaped filename. Mac example: echo {} | pbcopy. Escaping is not +# reliable/safe on Windows. +post_redirect_command = + # Syntax coloring style. Possible values (many support the "-dark" suffix): # manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, # friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, # fruity. -# Screenshots at http://mycli.net/syntax +# Screenshots at https://mycli.net/syntax # Can be further modified in [colors] syntax_style = default @@ -56,24 +111,78 @@ key_bindings = emacs wider_completion_menu = False # MySQL prompt -# \D - The full current date -# \d - Database name -# \h - Hostname of the server -# \m - Minutes of the current time -# \n - Newline -# \P - AM/PM -# \p - Port -# \R - The current time, in 24-hour military time (0-23) -# \r - The current time, standard 12-hour time (1-12) -# \s - Seconds of the current time -# \t - Product type (Percona, MySQL, MariaDB, TiDB) -# \A - DSN alias name (from the [alias_dsn] section) -# \u - Username -# \x1b[...m - insert ANSI escape sequence +# * \D - full current date, e.g. Sat Feb 14 15:55:48 2026 +# * \R - current hour in 24-hour time (00–23) +# * \r - current hour in 12-hour time (01–12) +# * \m - minutes of the current time +# * \s - seconds of the current time +# * \P - AM/PM +# * \d - selected database/schema +# * \h - hostname of the server +# * \H - shortened hostname of the server +# * \p - connection port +# * \j - connection socket basename +# * \J - full connection socket path +# * \k - connection socket basename OR the port +# * \K - full connection socket path OR the port +# * \T - connection SSL/TLS version +# * \t - database vendor (Percona, MySQL, MariaDB, TiDB) +# * \u - username +# * \w - number of warnings, or "(none)" (requires frequent trips to the server) +# * \W - number of warnings, or the empty string (requires frequent trips to the server) +# * \y - uptime in seconds (requires frequent trips to the server) +# * \Y - uptime in words (requires frequent trips to the server) +# * \A - DSN alias +# * \n - a newline +# * \_ - a space +# * \\ - a literal backslash +# * \x1b[...m - an ANSI escape sequence (can style with color or attributes) +# ANSI color example: prompt = '\x1b[31mroot\x1b[0m@localhost:\d> ' +# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML. +# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html . +# Characters such as "&" or literal "<" and ">" must be HTML-escaped in this mode. +# HTML styles cannot be combined with ANSI sequences. HTML mode takes precedence. +# HTML color example: prompt = '\root@localhost:\d> ' +# prompt = '\t \u@\h:\d> ' prompt_continuation = '->' -# Skip intro info on startup and outro info on exit +# Use the same prompt format strings to construct a status line in the toolbar, +# where \B in the first position refers to the default toolbar showing keystrokes +# and state. Example: +# +# toolbar = '\B\d \D' +# +# If \B is included, the additional content will begin on the next line. More +# lines can be added with \n. If \B is not included, the customized toolbar +# can be a single line. An empty value is the same as the default "\B". The +# special literal value "None" will suppress the toolbar from appearing. +toolbar = '' + +# Use the same prompt format strings to construct a terminal tab title. +# The original XTerm docs call this title the "window title", but it now +# probably refers to a terminal tab. This title is only updated as frequently +# as the database is changed. +terminal_tab_title = '' + +# Use the same prompt format strings to construct a terminal window title. +# The original XTerm docs call this title the "icon title", but it now +# probably refers to a terminal window which contains tabs. This title is +# only updated as frequently as the database is changed. +terminal_window_title = '' + +# Use the same prompt format strings to construct a window title in a terminal +# multiplexer. Currently only tmux is supported. This title is only updated +# as frequently as the database is changed. +multiplex_window_title = '' + +# Use the same prompt format strings to construct a pane title in a terminal +# multiplexer. Currently only tmux is supported. This title is only updated +# as frequently as the database is changed. +multiplex_pane_title = '' + +# Skip intro info on startup and outro info on exit, and generally reduce +# feedback. This is equivalent to giving --quiet at the command line. less_chatty = False # Use alias from --login-path instead of host name in prompt @@ -92,8 +201,111 @@ enable_pager = True # Choose a specific pager pager = 'less' -# Custom colors for the completion menu, toolbar, etc. +# whether to show verbose warnings about the transition away from reading my.cnf +my_cnf_transition_done = False + +# Whether to store and retrieve passwords from the system keyring. +# See the documentation for https://pypi.org/project/keyring/ for your OS. +# Note that the hostname is considered to be different if short or qualified. +# This can be overridden with --use-keyring= at the CLI. +# A password can be reset with --use-keyring=reset at the CLI. +use_keyring = False + +[search] + +# Whether to apply syntax highlighting to the preview window in fuzzy history +# search. There is a small performance penalty to enabling this. The "pygmentize" +# CLI tool must also be available. The syntax style from the "syntax_style" +# option will be respected, though additional customizations from [colors] will +# not be applied. +highlight_preview = False + +[connection] + +# character set for connections without --character-set being set +default_character_set = utf8mb4 + +# whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set +default_local_infile = False + +# How often to send periodic background pings to the server when input is idle. Ticks are +# roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. +default_keepalive_ticks = 0 + +# Sets the desired behavior for handling secure connections to the database server. +# Possible values: +# auto = SSL is preferred for TCP/IP connections. Will attempt to connect via SSL, but will fall +# back to cleartext as needed. Will not attempt to connect with SSL over local sockets. +# on = SSL is required. Will attempt to connect via SSL even on a local socket, and will fail if +# a secure connection is not established. +# off = do not use SSL. Will fail if the server requires a secure connection. +default_ssl_mode = auto + +# SSL CA file for connections without --ssl-ca being set +default_ssl_ca = + +# SSL CA directory for connections without --ssl-capath being set +default_ssl_capath = + +# SSL X509 cert path for connections without --ssl-cert being set +default_ssl_cert = + +# SSL X509 key for connections without --ssl-key being set +default_ssl_key = + +# SSL cipher to use for connections without --ssl-cipher being set +default_ssl_cipher = + +# whether to verify server's "Common Name" in its cert, for connections without +# --ssl-verify-server-cert being set +default_ssl_verify_server_cert = False + +[llm] + +# If set to a positive integer, truncate text/binary fields to that width +# in bytes when sending sample data, to conserve tokens. Suggestion: 1024. +prompt_field_truncate = None + +# If set to a positive integer, attempt to truncate various sections of LLM +# prompt input to that number in bytes, to conserve tokens. Suggestion: +# 1000000. +prompt_section_truncate = None + +[keys] + +# possible values: exit, none +control_d = exit + +# possible values: auto, fzf, reverse_isearch +control_r = auto + +# comma-separated list: toolkit_default, summon, advancing_summon, prefixing_summon, advance, cancel +# +# * toolkit_default - ignore other behaviors and use prompt_toolkit's default bindings +# * summon - when completions are not visible, summon them +# * advancing_summon - when completions are not visible, summon them _and_ advance in the list +# * prefixing_summon - when completions are not visible, summon them _and_ insert the common prefix +# * advance - when completions are visible, advance in the list +# * cancel - when completions are visible, toggle the list off +control_space = summon, advance + +# comma-separated list: toolkit_default, summon, advancing_summon, prefixing_summon, advance, cancel +tab = advancing_summon, advance + +# How long to wait for an Escape key sequence in vi mode. +# 0.5 seconds is the prompt_toolkit default, but vi users may find that too long. +# Shorter values mean that "Escape" alone is recognized more quickly. +vi_ttimeoutlen = 0.1 + +# How long to wait for an Escape key sequence in Emacs mode. +emacs_ttimeoutlen = 0.5 + +# Custom colors for the completion menu, toolbar, etc, with actual support +# depending on the terminal, and the property being set. +# Colors: #ffffff, bg:#ffffff, border:#ffffff. +# Attributes: (no)blink, bold, dim, hidden, inherit, italic, reverse, strike, underline. [colors] +# Completion menus completion-menu.completion.current = 'bg:#ffffff #000000' completion-menu.completion = 'bg:#008888 #ffffff' completion-menu.meta.completion.current = 'bg:#44aaaa #000000' @@ -101,27 +313,47 @@ completion-menu.meta.completion = 'bg:#448888 #ffffff' completion-menu.multi-column-meta = 'bg:#aaffff #000000' scrollbar.arrow = 'bg:#003333' scrollbar = 'bg:#00aaaa' + +# The prompt +prompt = '' +continuation = '' + +# Colored table output (query results) +output.table-separator = "" +output.header = "#00ff5f bold" +output.odd-row = "" +output.even-row = "" +output.null = "#808080" +output.status = "" +output.status.warning-count = "" +output.timing = "" + +# Selected text (native selection; currently unused) selected = '#ffffff bg:#6666aa' + +# Search matches (for reverse i-search, not fuzzy search) search = '#ffffff bg:#4444aa' search.current = '#ffffff bg:#44aa44' + +# UI elements: bottom toolbar bottom-toolbar = 'bg:#222222 #aaaaaa' bottom-toolbar.off = 'bg:#222222 #888888' bottom-toolbar.on = 'bg:#222222 #ffffff' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# UI elements: other toolbars (currently unused) search-toolbar = 'noinherit bold' search-toolbar.text = 'nobold' system-toolbar = 'noinherit bold' arg-toolbar = 'noinherit bold' arg-toolbar.text = 'nobold' -bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' -bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' -# style classes for colored table output -output.header = "#00ff5f bold" -output.odd-row = "" -output.even-row = "" -output.null = "#808080" +# SQL enhacements: matching brackets +matching-bracket.cursor = '#ff8888 bg:#880000' +matching-bracket.other = '#000000 bg:#aacccc' -# SQL syntax highlighting overrides +# SQL syntax highlighting overrides: normally defined by main.syntax_style # sql.comment = 'italic #408080' # sql.comment.multi-line = '' # sql.comment.single-line = '' @@ -151,9 +383,22 @@ output.null = "#808080" # sql.whitespace = '' # Favorite queries. +# You can add your favorite queries here. They will be available in the +# REPL when you type `\f` or `\f `. [favorite_queries] +# example = "SELECT * FROM example_table WHERE id = 1" + +# Initial commands to execute when connecting to any database. +[init-commands] +# read_only = "SET SESSION TRANSACTION READ ONLY" + # Use the -d option to reference a DSN. # Special characters in passwords and other strings can be escaped with URL encoding. [alias_dsn] # example_dsn = mysql://[user[:password]@][host][:port][/dbname] + +# Initial commands to execute when connecting to a DSN alias. +[alias_dsn.init-commands] +# Define one or more SQL statements per alias (semicolon-separated). +# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'" diff --git a/mycli/output.py b/mycli/output.py new file mode 100644 index 00000000..eee1021a --- /dev/null +++ b/mycli/output.py @@ -0,0 +1,291 @@ +from __future__ import annotations + +from datetime import datetime +from decimal import Decimal +from io import TextIOWrapper +import itertools +import os +import shutil +from typing import Any, Generator, Literal, Protocol + +from cli_helpers.tabular_output import TabularOutputFormatter, preprocessors +from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as DEFAULT_MISSING_VALUE +from cli_helpers.utils import strip_ansi +import click +from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ( + ANSI, + HTML, + AnyFormattedText, + FormattedText, + to_formatted_text, + to_plain_text, +) +from prompt_toolkit.shortcuts import PromptSession +from prompt_toolkit.styles.style import _MergedStyle +from pygments.style import Style as PygmentsStyle +from pymysql.cursors import Cursor + +from mycli.compat import WIN +from mycli.constants import DEFAULT_HEIGHT, DEFAULT_WIDTH +import mycli.main_modes.repl as repl_mode +from mycli.packages import special +from mycli.packages.sqlresult import SQLResult +from mycli.packages.tabular_output import sql_format +from mycli.sqlexecute import FIELD_TYPES + + +class MyCliState(Protocol): + # Provided by AppStateMixin. + def read_my_cnf(self, cnf: ConfigObj, keys: list[str]) -> dict[str, Any]: ... + + # Provided by OutputMixin itself; declared so cross-method calls type-check. + def log_output(self, output: str | AnyFormattedText) -> None: ... + def get_output_margin(self, status: str | None = None) -> int: ... + def get_reserved_space(self) -> int: ... + + +class OutputMixin(MyCliState): + prompt_lines: int + multiline_continuation_char: str + multiplex_pane_title_format: str + multiplex_window_title_format: str + terminal_tab_title_format: str + terminal_window_title_format: str + toolbar_format: str + redirect_formatter: TabularOutputFormatter + config: ConfigObj + my_cnf: ConfigObj + logfile: TextIOWrapper | Literal[False] | None + prompt_session: PromptSession | None + prompt_format: str + explicit_pager: bool + ptoolkit_style: _MergedStyle + helpers_style: PygmentsStyle + helpers_warnings_style: PygmentsStyle + main_formatter: TabularOutputFormatter + + def output_timing(self, timing: str, is_warnings_style: bool = False) -> None: + self.log_output(timing) + add_style = 'class:warnings.timing' if is_warnings_style else 'class:output.timing' + formatted_timing = FormattedText([('', timing)]) + styled_timing = to_formatted_text(formatted_timing, style=add_style) + prompt_toolkit.print_formatted_text(styled_timing, style=self.ptoolkit_style) + + def log_query(self, query: str) -> None: + if isinstance(self.logfile, TextIOWrapper): + self.logfile.write(f"\n# {datetime.now()}\n") + self.logfile.write(query) + self.logfile.write("\n") + + def log_output(self, output: str | AnyFormattedText) -> None: + """Log the output in the audit log, if it's enabled.""" + if isinstance(output, (ANSI, HTML, FormattedText)): + output = to_plain_text(output) + if isinstance(self.logfile, TextIOWrapper): + click.echo(output, file=self.logfile) + + def echo(self, s: str, **kwargs) -> None: + """Print a message to stdout.""" + self.log_output(s) + click.secho(s, **kwargs) + + def get_output_margin(self, status: str | None = None) -> int: + """Get the output margin for prompt, footer, timing, and status.""" + if not self.prompt_lines: + if self.prompt_session and self.prompt_session.app: + render_counter = self.prompt_session.app.render_counter + else: + render_counter = 0 + prompt_string = repl_mode.render_prompt_string(self, self.prompt_format, render_counter) + self.prompt_lines = to_plain_text(prompt_string).count('\n') + 1 + margin = self.get_reserved_space() + self.prompt_lines + if special.is_timing_enabled(): + margin += 1 + if status: + margin += 1 + status.count("\n") + + return margin + + def output( + self, + output: itertools.chain[str], + result: SQLResult, + is_warnings_style: bool = False, + ) -> None: + """Output text to stdout or a pager command.""" + if output: + if self.prompt_session is not None: + size = self.prompt_session.output.get_size() + size_columns = size.columns + size_rows = size.rows + else: + size_columns = DEFAULT_WIDTH + size_rows = DEFAULT_HEIGHT + + margin = self.get_output_margin(result.status_plain) + + fits = True + buf = [] + output_via_pager = self.explicit_pager and special.is_pager_enabled() + for i, line in enumerate(output, 1): + self.log_output(line) + special.write_tee(line) + special.write_once(line) + special.write_pipe_once(line) + + if special.is_redirected(): + pass + elif fits or output_via_pager: + buf.append(line) + if len(line) > size_columns or i > (size_rows - margin): + fits = False + if not self.explicit_pager and special.is_pager_enabled(): + output_via_pager = True + + if not output_via_pager: + for buf_line in buf: + click.secho(buf_line) + buf = [] + else: + click.secho(line) + + if buf: + if output_via_pager: + + def newlinewrapper(text: list[str]) -> Generator[str, None, None]: + for line in text: + yield line + "\n" + + click.echo_via_pager(newlinewrapper(buf)) + else: + for line in buf: + click.secho(line) + + if result.status: + self.log_output(result.status_plain) + add_style = 'class:warnings.status' if is_warnings_style else 'class:output.status' + if isinstance(result.status, FormattedText): + status = result.status + else: + status = FormattedText([('', result.status_plain)]) + styled_status = to_formatted_text(status, style=add_style) + prompt_toolkit.print_formatted_text(styled_status, style=self.ptoolkit_style) + + def configure_pager(self) -> None: + if not os.environ.get("LESS"): + os.environ["LESS"] = "-RXF" + + cnf = self.read_my_cnf(self.my_cnf, ["pager", "skip-pager"]) + cnf_pager = cnf["pager"] or self.config["main"]["pager"] + + if WIN and cnf_pager == 'less' and not shutil.which(cnf_pager): + cnf_pager = 'more' + + if cnf_pager: + special.set_pager(cnf_pager) + self.explicit_pager = True + else: + self.explicit_pager = False + + if cnf["skip-pager"] or not self.config["main"].as_bool("enable_pager"): + special.disable_pager() + + def format_sqlresult( + self, + result, + is_expanded: bool = False, + is_redirected: bool = False, + null_string: str | None = None, + numeric_alignment: str = 'right', + binary_display: str | None = None, + max_width: int | None = None, + is_warnings_style: bool = False, + ) -> itertools.chain[str]: + if is_redirected: + use_formatter = self.redirect_formatter + else: + use_formatter = self.main_formatter + + is_expanded = is_expanded or use_formatter.format_name == "vertical" + output: itertools.chain[str] = itertools.chain() + + output_kwargs = { + "dialect": "unix", + "disable_numparse": True, + "preserve_whitespace": True, + "style": self.helpers_warnings_style if is_warnings_style else self.helpers_style, + } + default_kwargs = use_formatter._output_formats[use_formatter.format_name].formatter_args + + if null_string is not None and default_kwargs.get('missing_value') == DEFAULT_MISSING_VALUE: + output_kwargs['missing_value'] = null_string + + if use_formatter.format_name not in sql_format.supported_formats and binary_display != 'utf8': + output_kwargs["preprocessors"] = (preprocessors.convert_to_undecoded_string,) + + if result.preamble: + output = itertools.chain(output, [result.preamble]) + + if result.header or (result.rows and result.preamble): + column_types = None + colalign = None + if isinstance(result.rows, Cursor): + + def get_col_type(col) -> type: + col_type = FIELD_TYPES.get(col[1], str) + return col_type if type(col_type) is type else str + + if result.rows.rowcount > 0: + column_types = [get_col_type(tup) for tup in result.rows.description] + colalign = [numeric_alignment if x in (int, float, Decimal) else 'left' for x in column_types] + else: + column_types, colalign = [], [] + + if max_width is not None and isinstance(result.rows, Cursor): + result_rows = list(result.rows) + else: + result_rows = result.rows + + formatted = use_formatter.format_output( + result_rows, + result.header or [], + format_name="vertical" if is_expanded else None, + column_types=column_types, + colalign=colalign, + **output_kwargs, + ) + + if isinstance(formatted, str): + formatted = formatted.splitlines() + formatted = iter(formatted) + + if not is_expanded and max_width and result.header and result_rows: + first_line = next(formatted) + if len(strip_ansi(first_line)) > max_width: + formatted = use_formatter.format_output( + result_rows, + result.header, + format_name="vertical", + column_types=column_types, + **output_kwargs, + ) + if isinstance(formatted, str): + formatted = iter(formatted.splitlines()) + else: + formatted = itertools.chain([first_line], formatted) + + output = itertools.chain(output, formatted) + + if result.postamble: + output = itertools.chain(output, [result.postamble]) + + return output + + def get_reserved_space(self) -> int: + """Get the number of lines to reserve for the completion menu.""" + reserved_space_ratio = 0.45 + max_reserved_space = 8 + _, height = shutil.get_terminal_size() + return min(int(round(height * reserved_space_ratio)), max_reserved_space) diff --git a/mycli/packages/batch_utils.py b/mycli/packages/batch_utils.py new file mode 100644 index 00000000..d0ebd218 --- /dev/null +++ b/mycli/packages/batch_utils.py @@ -0,0 +1,36 @@ +from typing import IO, Generator + +import sqlglot +import sqlparse + +MAX_MULTILINE_BATCH_STATEMENT = 5000 + + +def statements_from_filehandle(file_h: IO) -> Generator[tuple[str, int], None, None]: + statements = '' + line_counter = 0 + batch_counter = 0 + for batch_text in file_h: + line_counter += 1 + if line_counter > MAX_MULTILINE_BATCH_STATEMENT: + raise ValueError(f'Saw single input statement greater than {MAX_MULTILINE_BATCH_STATEMENT} lines; assuming a parsing error.') + statements += batch_text + try: + tokens = sqlglot.tokenize(statements, read='mysql') + if not tokens: + continue + # we don't yet handle changing the delimiter within the batch input + if tokens[-1].text == ';': + # The advantage of sqlparse for splitting is that it preserves the input. + # https://github.com/tobymao/sqlglot/issues/2587#issuecomment-1823109501 + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 + statements = '' + line_counter = 0 + except sqlglot.errors.TokenError: + continue + if statements: + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 diff --git a/mycli/packages/cli_utils.py b/mycli/packages/cli_utils.py new file mode 100644 index 00000000..65950130 --- /dev/null +++ b/mycli/packages/cli_utils.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import sys + + +def filtered_sys_argv() -> list[str]: + args = sys.argv[1:] + if args == ['-h']: + args = ['--help'] + return args + + +def is_valid_connection_scheme(text: str) -> tuple[bool, str | None]: + # exit early if the text does not resemble a DSN URI + if "://" not in text: + return False, None + scheme = text.split("://")[0] + if scheme not in ("mysql", "mysqlx", "tcp", "socket", "ssh"): + return False, scheme + else: + return True, None diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 2735f5b8..f623a38c 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -1,10 +1,691 @@ +from dataclasses import dataclass +import functools +import re +from typing import Any, Callable, Literal + import sqlparse -from sqlparse.sql import Comparison, Identifier, Where -from .parseutils import last_word, extract_tables, find_prev_keyword -from .special import parse_special_command +from sqlparse.sql import Comparison, Identifier, Token, Where + +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.special.main import parse_special_command +from mycli.packages.sql_utils import extract_tables, find_prev_keyword, last_word + +sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] +sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] + +_ENUM_VALUE_RE = re.compile( + r"(?P(?:`[^`]+`|[\w$]+)(?:\.(?:`[^`]+`|[\w$]+))?)\s*=\s*$", + re.IGNORECASE, +) + +# missing because not binary +# BETWEEN +# CASE +# missing because parens are used +# IN(), and others +# unary operands might need to have another set +# not, !, ~ +# arrow operators only take a literal on the right +# and so might need different treatment +# := might also need a different context +# sqlparse would call these identifiers, so they are excluded +# xor +# these are hitting the recursion guard, and so not completing after +# so we might as well leave them out: +# is, 'is not', mod +# sqlparse might also parse "not null" together +# should also verify how sqlparse parses every space-containing case +BINARY_OPERANDS = { + '&', '>', '>>', '>=', '<', '<>', '!=', '<<', '<=', '<=>', '%', + '*', '+', '-', '->', '->>', '/', ':=', '=', '^', 'and', '&&', 'div', + 'like', 'not like', 'not regexp', 'or', '||', 'regexp', 'rlike', + 'sounds like', '|', +} # fmt: skip + +Suggestion = dict[str, Any] +Predicate = Callable[['SuggestContext'], bool] +Emitter = Callable[['SuggestContext'], list[Suggestion]] + + +@dataclass(frozen=True) +class SuggestContext: + token: str | Token | None + token_value: str | None + text_before_cursor: str + word_before_cursor: str | None + full_text: str + identifier: Identifier + parsed_cb: Callable[[], sqlparse.sql.Statement] + tokens_wo_space_cb: Callable[[], list[Token]] + + +@dataclass(frozen=True) +class SuggestRule: + name: str + predicate: Predicate + emit: Emitter + + +def _keyword_suggestions() -> list[Suggestion]: + return [{'type': 'keyword'}] + + +def _keyword_and_special_suggestions() -> list[Suggestion]: + return [{'type': 'keyword'}, {'type': 'special'}] + + +@functools.lru_cache(maxsize=128) +def _parse_suggestion_statement(text_before_cursor: str) -> sqlparse.sql.Statement: + try: + return sqlparse.parse(text_before_cursor)[0] + except (AttributeError, IndexError, ValueError, sqlparse.exceptions.SQLParseError): + return sqlparse.sql.Statement() + + +@functools.lru_cache(maxsize=128) +def _tokens_wo_space(text_before_cursor: str) -> list[Token]: + parsed = _parse_suggestion_statement(text_before_cursor) + return [x for x in parsed.tokens if x.ttype != sqlparse.tokens.Token.Text.Whitespace] + + +def _normalize_token_value(token: str | Token | None) -> str | None: + if isinstance(token, str): + return token.lower() + if isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + return token.tokens[-1].value.lower() + if token is None: + return None + return token.value.lower() + + +def _build_suggest_context( + token: str | Token | None, + text_before_cursor: str, + word_before_cursor: str | None, + full_text: str, + identifier: Identifier, +) -> SuggestContext: + return SuggestContext( + token=token, + token_value=_normalize_token_value(token), + text_before_cursor=text_before_cursor, + word_before_cursor=word_before_cursor, + full_text=full_text, + identifier=identifier, + parsed_cb=functools.partial(_parse_suggestion_statement, text_before_cursor), + tokens_wo_space_cb=functools.partial(_tokens_wo_space, text_before_cursor), + ) + + +def _is_single_or_double_quoted(ctx: SuggestContext) -> bool: + return is_inside_quotes(ctx.text_before_cursor, -1) in ['single', 'double'] + + +def _parent_name(ctx: SuggestContext) -> str | list[Any]: + return (ctx.identifier and ctx.identifier.get_parent_name()) or [] + + +def _tables(ctx: SuggestContext) -> list[tuple[str | None, str, str]]: + return extract_tables(ctx.full_text) + + +def _aliases(tables: list[tuple[str | None, str, str]]) -> list[str]: + return [alias or table for (schema, table, alias) in tables] + + +def _emit_none_token(_ctx: SuggestContext) -> list[Suggestion]: + return _keyword_suggestions() + + +def _emit_blank_token(_ctx: SuggestContext) -> list[Suggestion]: + return _keyword_and_special_suggestions() + + +def _emit_star(_ctx: SuggestContext) -> list[Suggestion]: + return _keyword_suggestions() + + +def _emit_lparen(ctx: SuggestContext) -> list[Suggestion]: + if ctx.parsed_cb().tokens and isinstance(ctx.parsed_cb().tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + # override a few properties in the SuggestContext + column_suggestions = _emit_select_like( + SuggestContext( + token='where', + token_value='where', + text_before_cursor=ctx.text_before_cursor, + word_before_cursor=None, + full_text=ctx.full_text, + identifier=ctx.identifier, + parsed_cb=ctx.parsed_cb, + tokens_wo_space_cb=ctx.tokens_wo_space_cb, + ) + ) + + # Check for a subquery expression (cases 3 & 4) + where = ctx.parsed_cb().tokens[-1] + _idx, prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. "SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == 'exists': + return _keyword_suggestions() + return column_suggestions + + # Get the token before the parens + _idx, prev_tok = ctx.parsed_cb().token_prev(len(ctx.parsed_cb().tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': + # tbl1 INNER JOIN tbl2 USING (col1, col2) + # suggest columns that are present in more than one table + return [{'type': 'column', 'tables': _tables(ctx), 'drop_unique': True}] + if ctx.parsed_cb().tokens and ctx.parsed_cb().token_first() and ctx.parsed_cb().token_first().value.lower() == 'select': + # If the lparen is preceeded by a space chances are we're about to + # do a sub-select. + if last_word(ctx.text_before_cursor, 'all_punctuations').startswith('('): + return _keyword_suggestions() + elif ctx.parsed_cb().tokens and ctx.parsed_cb().token_first() and ctx.parsed_cb().token_first().value.lower() == 'show': + return [{'type': 'show'}] + + # We're probably in a function argument list + return [{'type': 'column', 'tables': _tables(ctx)}] + + +def _emit_procedure(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'procedure', 'schema': []}] + + +def _emit_character_set(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'character_set'}] + + +def _emit_column_for_tables(ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'column', 'tables': _tables(ctx)}] + + +def _emit_nothing(_ctx: SuggestContext) -> list[Suggestion]: + return [] + + +def _emit_show(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'show'}] + + +def _emit_to(ctx: SuggestContext) -> list[Suggestion]: + if ctx.parsed_cb().tokens and ctx.parsed_cb().token_first() and ctx.parsed_cb().token_first().value.lower() == 'change': + return [{'type': 'change'}] + return [{'type': 'user'}] + + +def _emit_user(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'user'}] + + +def _emit_collation(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'collation'}] + + +def _emit_select_like(ctx: SuggestContext) -> list[Suggestion]: + parent = _parent_name(ctx) + tables = _tables(ctx) + if parent: + tables = [t for t in tables if identifies(parent, *t)] + return [ + {'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}, + ] + if is_inside_quotes(ctx.text_before_cursor, -1) == 'backtick': + # todo: this should be revised, since we complete too exuberantly within + # backticks, including keywords + aliases = _aliases(tables) + return [ + {'type': 'column', 'tables': tables}, + {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': aliases}, + {'type': 'keyword'}, + ] + + aliases = _aliases(tables) + return [ + {'type': 'column', 'tables': tables}, + {'type': 'function', 'schema': []}, + {'type': 'introducer'}, + {'type': 'alias', 'aliases': aliases}, + ] + + +def _emit_relation_like(ctx: SuggestContext) -> list[Suggestion]: + schema = _parent_name(ctx) + is_join = bool(ctx.token_value and ctx.token_value.endswith('join') and isinstance(ctx.token, Token) and ctx.token.is_keyword) + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + table_suggestion: Suggestion = {'type': 'table', 'schema': schema} + if is_join: + table_suggestion['join'] = True + suggest: list[Suggestion] = [table_suggestion] + + if not schema: + # Suggest schemas + suggest.append({'type': 'database'}) + + # Only tables can be TRUNCATED, otherwise suggest views + if ctx.token_value != 'truncate': + suggest.append({'type': 'view', 'schema': schema}) + + return suggest + + +def _emit_relation_name(ctx: SuggestContext) -> list[Suggestion]: + rel_type = ctx.token_value + assert rel_type is not None + schema = _parent_name(ctx) + if schema: + return [{'type': rel_type, 'schema': schema}] + return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] + + +def _emit_on(ctx: SuggestContext) -> list[Suggestion]: + tables = _tables(ctx) # [(schema, table, alias), ...] + parent = _parent_name(ctx) + if parent: + # "ON parent." + # parent can be either a schema name or table alias + # todo recognize and separate schema and table suggestions + # todo remove function suggestions here + tables = [t for t in tables if identifies(parent, *t)] + return [ + {'type': 'column', 'tables': tables}, + {'type': 'table', 'schema': parent}, + {'type': 'view', 'schema': parent}, + {'type': 'function', 'schema': parent}, + ] + + # ON + # Use table alias if there is one, otherwise the table name + aliases = _aliases(tables) + suggest: list[Suggestion] = [{'type': 'fk_join', 'tables': tables}, {'type': 'alias', 'aliases': aliases}] + + # The lists of 'aliases' could be empty if we're trying to complete + # a GRANT query. eg: GRANT SELECT, INSERT ON + # In that case we just suggest all schemata and all tables. + if not aliases: + suggest.append({'type': 'database'}) + suggest.append({'type': 'table', 'schema': parent}) + return suggest + + +def _emit_database(_ctx: SuggestContext) -> list[Suggestion]: + return [{'type': 'database'}] + + +def _emit_where_token(ctx: SuggestContext) -> list[Suggestion]: + assert isinstance(ctx.token, Where) + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly. + # + # This logic also needs to look even deeper in to the WHERE clause. + # We recapitulate some transcoding suggestions here, but cannot + # recapitulate the entire logic of this function. + where_tokens = [x for x in ctx.token.tokens if x.ttype != sqlparse.tokens.Token.Text.Whitespace] + if transcoding_suggestion := _charset_suggestion(where_tokens): + return transcoding_suggestion + + original_text = ctx.text_before_cursor + prev_keyword, rewound_text = find_prev_keyword(ctx.text_before_cursor) + enum_suggestion = _enum_value_suggestion(original_text, ctx.full_text) + fallback = suggest_based_on_last_token(prev_keyword, rewound_text, None, ctx.full_text, ctx.identifier) + if enum_suggestion and _is_where_or_having(prev_keyword): + return [enum_suggestion] + fallback + return fallback + + +def _emit_binary_or_comma(ctx: SuggestContext) -> list[Suggestion]: + original_text = ctx.text_before_cursor + prev_keyword, rewound_text = find_prev_keyword(ctx.text_before_cursor) + enum_suggestion = _enum_value_suggestion(original_text, ctx.full_text) + + # guard against non-progressing parser rewinds, which can otherwise + # recurse forever on some operator shapes. + if prev_keyword and rewound_text.rstrip() != original_text.rstrip(): + fallback = suggest_based_on_last_token(prev_keyword, rewound_text, None, ctx.full_text, ctx.identifier) + else: + # perhaps this fallback should include columns + fallback = _keyword_suggestions() + + if enum_suggestion and _is_where_or_having(prev_keyword): + return [enum_suggestion] + fallback + return fallback + + +def _word_starts_with_digit_or_dot(ctx: SuggestContext) -> bool: + return bool(ctx.word_before_cursor and re.match(r'^[\d\.]', ctx.word_before_cursor[0])) + + +def _word_starts_with_quote(ctx: SuggestContext) -> bool: + return bool(ctx.word_before_cursor and ctx.word_before_cursor[0] in ('"', "'")) + + +def _word_inside_single_or_double_quotes(ctx: SuggestContext) -> bool: + return bool(ctx.word_before_cursor and _is_single_or_double_quoted(ctx)) + + +def _token_is_none(ctx: SuggestContext) -> bool: + return ctx.token is None + + +def _token_is_blank(ctx: SuggestContext) -> bool: + return not ctx.token + + +def _token_value_is(ctx: SuggestContext, *values: str) -> bool: + return bool(ctx.token_value and ctx.token_value in values) + + +def _token_is_lparen(ctx: SuggestContext) -> bool: + return bool(ctx.token_value and ctx.token_value.endswith('(')) + + +def _token_is_relation_keyword(ctx: SuggestContext) -> bool: + return bool( + (ctx.token_value and ctx.token_value.endswith('join') and isinstance(ctx.token, Token) and ctx.token.is_keyword) + or (ctx.token_value in ('copy', 'from', 'update', 'into', 'describe', 'truncate', 'desc', 'explain')) + or (ctx.token_value == 'like' and re.match(r'^\s*create\s+table\s', ctx.full_text, re.IGNORECASE)) + ) + + +def _token_is_binary_or_comma(ctx: SuggestContext) -> bool: + return bool(ctx.token_value and (ctx.token_value.endswith(',') or ctx.token_value in BINARY_OPERANDS)) + + +SUGGEST_BASED_ON_LAST_TOKEN_RULES = [ + SuggestRule( + 'guard_number_or_dot', + _word_starts_with_digit_or_dot, + _emit_nothing, + ), + SuggestRule( + 'guard_quote_prefix', + _word_starts_with_quote, + _emit_nothing, + ), + SuggestRule( + 'guard_inside_single_or_double', + _word_inside_single_or_double_quotes, + _emit_nothing, + ), + SuggestRule( + 'where_token', + lambda ctx: isinstance(ctx.token, Where), + _emit_where_token, + ), + SuggestRule( + 'none_token', + _token_is_none, + _emit_none_token, + ), + SuggestRule( + 'blank_token', + _token_is_blank, + _emit_blank_token, + ), + SuggestRule( + 'star_token', + lambda ctx: _token_value_is(ctx, '*'), + _emit_star, + ), + SuggestRule( + 'lparen_token', + _token_is_lparen, + _emit_lparen, + ), + SuggestRule( + 'call', + lambda ctx: _token_value_is(ctx, 'call'), + _emit_procedure, + ), + SuggestRule( + 'character_set_after_character', + lambda ctx: ( + _token_value_is(ctx, 'set') and len(ctx.tokens_wo_space_cb()) >= 3 and ctx.tokens_wo_space_cb()[-3].value.lower() == 'character' + ), + _emit_character_set, + ), + SuggestRule( + 'character_set_after_character_short', + lambda ctx: ( + _token_value_is(ctx, 'set') and len(ctx.tokens_wo_space_cb()) >= 2 and ctx.tokens_wo_space_cb()[-2].value.lower() == 'character' + ), + _emit_character_set, + ), + SuggestRule( + 'set_order_by_distinct', + lambda ctx: _token_value_is(ctx, 'set', 'order by', 'distinct'), + _emit_column_for_tables, + ), + SuggestRule( + 'as', + lambda ctx: _token_value_is(ctx, 'as'), + _emit_nothing, + ), + SuggestRule( + 'show', + lambda ctx: _token_value_is(ctx, 'show'), + _emit_show, + ), + SuggestRule( + 'to', + lambda ctx: _token_value_is(ctx, 'to'), + _emit_to, + ), + SuggestRule( + 'user_or_for', + lambda ctx: _token_value_is(ctx, 'user', 'for'), + _emit_user, + ), + SuggestRule( + 'collate', + lambda ctx: _token_value_is(ctx, 'collate'), + _emit_collation, + ), + SuggestRule( + 'using_after_convert_long', + lambda ctx: ( + _token_value_is(ctx, 'using') and len(ctx.tokens_wo_space_cb()) >= 5 and ctx.tokens_wo_space_cb()[-5].value.lower() == 'convert' + ), + _emit_character_set, + ), + SuggestRule( + 'using_after_convert_short', + lambda ctx: ( + _token_value_is(ctx, 'using') and len(ctx.tokens_wo_space_cb()) >= 4 and ctx.tokens_wo_space_cb()[-4].value.lower() == 'convert' + ), + _emit_character_set, + ), + SuggestRule( + 'select_where_having', + lambda ctx: _token_value_is(ctx, 'select', 'where', 'having'), + _emit_select_like, + ), + SuggestRule( + 'relation_keyword', + _token_is_relation_keyword, + _emit_relation_like, + ), + SuggestRule( + 'relation_name', + lambda ctx: _token_value_is(ctx, 'table', 'view', 'function'), + _emit_relation_name, + ), + SuggestRule( + 'on', + lambda ctx: _token_value_is(ctx, 'on'), + _emit_on, + ), + SuggestRule( + 'database_template', + lambda ctx: _token_value_is(ctx, 'database', 'template'), + _emit_database, + ), + SuggestRule( + 'inside_single_or_double', + _is_single_or_double_quoted, + _emit_nothing, + ), + SuggestRule( + 'binary_or_comma', + _token_is_binary_or_comma, + _emit_binary_or_comma, + ), +] + + +def _enum_value_suggestion(text_before_cursor: str, full_text: str) -> dict[str, Any] | None: + match = _ENUM_VALUE_RE.search(text_before_cursor) + if not match: + return None + if is_inside_quotes(text_before_cursor, match.start("lhs")): + return None + + lhs = match.group("lhs") + if "." in lhs: + parent, column = lhs.split(".", 1) + else: + parent, column = None, lhs + + return { + "type": "enum_value", + "tables": extract_tables(full_text), + "column": column, + "parent": parent, + } + + +def _charset_suggestion(tokens: list[Token]) -> list[dict[str, str]] | None: + token_values = [token.value.lower() for token in tokens if token.value] + + if len(token_values) >= 2 and token_values[-1] == 'set' and token_values[-2] == 'character': + return [{'type': 'character_set'}] + if len(token_values) >= 3 and token_values[-2] == 'set' and token_values[-3] == 'character': + return [{'type': 'character_set'}] + if len(token_values) >= 5 and token_values[-1] == 'using' and token_values[-4] == 'convert': + return [{'type': 'character_set'}] + if len(token_values) >= 6 and token_values[-2] == 'using' and token_values[-5] == 'convert': + return [{'type': 'character_set'}] + if len(token_values) >= 1 and token_values[-1] == 'collate': + return [{'type': 'collation'}] + + return None + + +def _is_where_or_having(token: Token | None) -> bool: + return bool(token and token.value and token.value.lower() in ("where", "having")) + + +def _find_doubled_backticks(text: str) -> list[int]: + length = len(text) + doubled_backtick_positions: list[int] = [] + backtick = '`' + two_backticks = backtick + backtick + + if two_backticks not in text: + return doubled_backtick_positions + + for index in range(0, length): + ch = text[index] + if ch != backtick: + index += 1 + continue + if index + 1 < length and text[index + 1] == backtick: + doubled_backtick_positions.append(index) + doubled_backtick_positions.append(index + 1) + index += 2 + continue + index += 1 + + return doubled_backtick_positions + + +@functools.lru_cache(maxsize=128) +def is_inside_quotes(text: str, pos: int) -> Literal[False, 'single', 'double', 'backtick']: + in_single = False + in_double = False + in_backticks = False + escaped = False + doubled_backtick_positions = [] + single_quote = "'" + double_quote = '"' + backtick = '`' + backslash = '\\' + + # scanning the string twice seems to be needed to handle doubled backticks + doubled_backtick_positions = _find_doubled_backticks(text) + + length = len(text) + if pos < 0: + pos = length + pos + pos = max(pos, 0) + pos = min(length, pos) + + # optimization + up_to_pos = text[:pos] + if backtick not in up_to_pos and single_quote not in up_to_pos and double_quote not in up_to_pos: + return False + + for index in range(0, pos): + ch = text[index] + if index in doubled_backtick_positions: + index += 1 + continue + if escaped and (in_double or in_single): + escaped = False + index += 1 + continue + if ch == backslash and (in_double or in_single): + escaped = True + index += 1 + continue + if ch == backtick and not in_double and not in_single: + in_backticks = not in_backticks + elif ch == single_quote and not in_double and not in_backticks: + in_single = not in_single + elif ch == double_quote and not in_single and not in_backticks: + in_double = not in_double + index += 1 + + if in_single: + return 'single' + elif in_double: + return 'double' + elif in_backticks: + return 'backtick' + else: + return False -def suggest_type(full_text, text_before_cursor): +def suggest_type(full_text: str, text_before_cursor: str) -> list[dict[str, Any]]: """Takes the full_text that is typed so far and also the text before the cursor to suggest completion type and scope. @@ -12,10 +693,9 @@ def suggest_type(full_text, text_before_cursor): A scope for a column category will be a list of tables. """ - word_before_cursor = last_word(text_before_cursor, - include='many_punctuations') + word_before_cursor = last_word(text_before_cursor, include="many_punctuations") - identifier = None + identifier: Identifier | None = None # here should be removed once sqlparse has been fixed try: @@ -25,12 +705,10 @@ def suggest_type(full_text, text_before_cursor): # partially typed string which renders the smart completion useless because # it will always return the list of keywords as completion. if word_before_cursor: - if word_before_cursor.endswith( - '(') or word_before_cursor.startswith('\\'): + if word_before_cursor.endswith("(") or word_before_cursor.startswith("\\"): parsed = sqlparse.parse(text_before_cursor) else: - parsed = sqlparse.parse( - text_before_cursor[:-len(word_before_cursor)]) + parsed = sqlparse.parse(text_before_cursor[: -len(word_before_cursor)]) # word_before_cursor may include a schema qualification, like # "schema_name.partial_name" or "schema_name.", so parse it @@ -42,7 +720,7 @@ def suggest_type(full_text, text_before_cursor): else: parsed = sqlparse.parse(text_before_cursor) except (TypeError, AttributeError): - return [{'type': 'keyword'}] + return [{"type": "keyword"}] if len(parsed) > 1: # Multiple statements being edited -- isolate the current one by @@ -72,223 +750,93 @@ def suggest_type(full_text, text_before_cursor): # Be careful here because trivial whitespace is parsed as a statement, # but the statement won't have a first token tok1 = statement.token_first() - if tok1 and (tok1.value == 'source' or tok1.value.startswith('\\')): + # lenient because \. will parse as two tokens + if tok1 and tok1.value.startswith('\\'): return suggest_special(text_before_cursor) + elif tok1: + if tok1.value.lower() in SPECIAL_COMMANDS: + return suggest_special(text_before_cursor) - last_token = statement and statement.token_prev(len(statement.tokens))[1] or '' + last_token = statement and statement.token_prev(len(statement.tokens))[1] or "" - return suggest_based_on_last_token(last_token, text_before_cursor, - full_text, identifier) + # todo: unsure about empty string as identifier + return suggest_based_on_last_token(last_token, text_before_cursor, word_before_cursor, full_text, identifier or Identifier('')) -def suggest_special(text): +def suggest_special(text: str) -> list[dict[str, Any]]: text = text.lstrip() - cmd, _, arg = parse_special_command(text) + cmd, _separator, _arg = parse_special_command(text) if cmd == text: # Trying to complete the special command itself - return [{'type': 'special'}] + return [{"type": "special"}] + + if cmd in ("\\u", "\\r"): + return [{"type": "database"}] - if cmd in ('\\u', '\\r'): + if cmd.lower() in ('use', 'connect'): return [{'type': 'database'}] - if cmd in ('\\T'): - return [{'type': 'table_format'}] + if cmd in (r'\T', r'\Tr'): + return [{"type": "table_format"}] - if cmd in ['\\f', '\\fs', '\\fd']: - return [{'type': 'favoritequery'}] + if cmd.lower() in ('tableformat', 'redirectformat'): + return [{"type": "table_format"}] - if cmd in ['\\dt', '\\dt+']: + if cmd in ["\\f", "\\fs", "\\fd"]: + return [{"type": "favoritequery"}] + + if cmd in ["\\dt", "\\dt+"]: return [ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "schema"}, ] - elif cmd in ['\\.', 'source']: - return[{'type': 'file_name'}] - - return [{'type': 'keyword'}, {'type': 'special'}] - - -def suggest_based_on_last_token(token, text_before_cursor, full_text, identifier): - if isinstance(token, str): - token_v = token.lower() - elif isinstance(token, Comparison): - # If 'token' is a Comparison type such as - # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling - # token.value on the comparison type will only return the lhs of the - # comparison. In this case a.id. So we need to do token.tokens to get - # both sides of the comparison and pick the last token out of that - # list. - token_v = token.tokens[-1].value.lower() - elif isinstance(token, Where): - # sqlparse groups all tokens from the where clause into a single token - # list. This means that token.value may be something like - # 'where foo > 5 and '. We need to look "inside" token.tokens to handle - # suggestions in complicated where clauses correctly - prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) - return suggest_based_on_last_token(prev_keyword, text_before_cursor, - full_text, identifier) - elif token is None: - return [{'type': 'keyword'}] - else: - token_v = token.value.lower() - - is_operand = lambda x: x and any([x.endswith(op) for op in ['+', '-', '*', '/']]) - - if not token: - return [{'type': 'keyword'}, {'type': 'special'}] - elif token_v.endswith('('): - p = sqlparse.parse(text_before_cursor)[0] - - if p.tokens and isinstance(p.tokens[-1], Where): - # Four possibilities: - # 1 - Parenthesized clause like "WHERE foo AND (" - # Suggest columns/functions - # 2 - Function call like "WHERE foo(" - # Suggest columns/functions - # 3 - Subquery expression like "WHERE EXISTS (" - # Suggest keywords, in order to do a subquery - # 4 - Subquery OR array comparison like "WHERE foo = ANY(" - # Suggest columns/functions AND keywords. (If we wanted to be - # really fancy, we could suggest only array-typed columns) - - column_suggestions = suggest_based_on_last_token('where', - text_before_cursor, full_text, identifier) - - # Check for a subquery expression (cases 3 & 4) - where = p.tokens[-1] - idx, prev_tok = where.token_prev(len(where.tokens) - 1) - - if isinstance(prev_tok, Comparison): - # e.g. "SELECT foo FROM bar WHERE foo = ANY(" - prev_tok = prev_tok.tokens[-1] - - prev_tok = prev_tok.value.lower() - if prev_tok == 'exists': - return [{'type': 'keyword'}] - else: - return column_suggestions - - # Get the token before the parens - idx, prev_tok = p.token_prev(len(p.tokens) - 1) - if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': - # tbl1 INNER JOIN tbl2 USING (col1, col2) - tables = extract_tables(full_text) - - # suggest columns that are present in more than one table - return [{'type': 'column', 'tables': tables, 'drop_unique': True}] - elif p.token_first().value.lower() == 'select': - # If the lparen is preceeded by a space chances are we're about to - # do a sub-select. - if last_word(text_before_cursor, - 'all_punctuations').startswith('('): - return [{'type': 'keyword'}] - elif p.token_first().value.lower() == 'show': - return [{'type': 'show'}] - - # We're probably in a function argument list - return [{'type': 'column', 'tables': extract_tables(full_text)}] - elif token_v in ('set', 'order by', 'distinct'): - return [{'type': 'column', 'tables': extract_tables(full_text)}] - elif token_v == 'as': - # Don't suggest anything for an alias - return [] - elif token_v in ('show'): - return [{'type': 'show'}] - elif token_v in ('to',): - p = sqlparse.parse(text_before_cursor)[0] - if p.token_first().value.lower() == 'change': - return [{'type': 'change'}] - else: - return [{'type': 'user'}] - elif token_v in ('user', 'for'): - return [{'type': 'user'}] - elif token_v in ('select', 'where', 'having'): - # Check for a table alias or schema qualification - parent = (identifier and identifier.get_parent_name()) or [] - - tables = extract_tables(full_text) - if parent: - tables = [t for t in tables if identifies(parent, *t)] - return [{'type': 'column', 'tables': tables}, - {'type': 'table', 'schema': parent}, - {'type': 'view', 'schema': parent}, - {'type': 'function', 'schema': parent}] - else: - aliases = [alias or table for (schema, table, alias) in tables] - return [{'type': 'column', 'tables': tables}, - {'type': 'function', 'schema': []}, - {'type': 'alias', 'aliases': aliases}, - {'type': 'keyword'}] - elif (token_v.endswith('join') and token.is_keyword) or (token_v in - ('copy', 'from', 'update', 'into', 'describe', 'truncate', - 'desc', 'explain')): - schema = (identifier and identifier.get_parent_name()) or [] - - # Suggest tables from either the currently-selected schema or the - # public schema if no schema has been specified - suggest = [{'type': 'table', 'schema': schema}] - - if not schema: - # Suggest schemas - suggest.insert(0, {'type': 'schema'}) - - # Only tables can be TRUNCATED, otherwise suggest views - if token_v != 'truncate': - suggest.append({'type': 'view', 'schema': schema}) - - return suggest - - elif token_v in ('table', 'view', 'function'): - # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' - rel_type = token_v - schema = (identifier and identifier.get_parent_name()) or [] - if schema: - return [{'type': rel_type, 'schema': schema}] - else: - return [{'type': 'schema'}, {'type': rel_type, 'schema': []}] - elif token_v == 'on': - tables = extract_tables(full_text) # [(schema, table, alias), ...] - parent = (identifier and identifier.get_parent_name()) or [] - if parent: - # "ON parent." - # parent can be either a schema name or table alias - tables = [t for t in tables if identifies(parent, *t)] - return [{'type': 'column', 'tables': tables}, - {'type': 'table', 'schema': parent}, - {'type': 'view', 'schema': parent}, - {'type': 'function', 'schema': parent}] - else: - # ON - # Use table alias if there is one, otherwise the table name - aliases = [alias or table for (schema, table, alias) in tables] - suggest = [{'type': 'alias', 'aliases': aliases}] - - # The lists of 'aliases' could be empty if we're trying to complete - # a GRANT query. eg: GRANT SELECT, INSERT ON - # In that case we just suggest all tables. - if not aliases: - suggest.append({'type': 'table', 'schema': parent}) - return suggest - - elif token_v in ('use', 'database', 'template', 'connect'): - # "\c ", "DROP DATABASE ", - # "CREATE DATABASE WITH TEMPLATE " - return [{'type': 'database'}] - elif token_v == 'tableformat': - return [{'type': 'table_format'}] - elif token_v.endswith(',') or is_operand(token_v) or token_v in ['=', 'and', 'or']: - prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) - if prev_keyword: - return suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier) - else: - return [] - else: - return [{'type': 'keyword'}] - - -def identifies(id, schema, table, alias): - return id == alias or id == table or ( - schema and (id == schema + '.' + table)) + elif cmd.lower() in [ + r'\.', + 'source', + r'\o', + r'\once', + r'tee', + ]: + return [{"type": "file_name"}] + # todo: why is \edit case-sensitive? + elif cmd in [ + r'\e', + r'\edit', + ]: + return [{"type": "file_name"}] + if cmd in ["\\llm", "\\ai"]: + return [{"type": "llm"}] + + return [{"type": "keyword"}, {"type": "special"}] + + +def suggest_based_on_last_token( + token: str | Token | None, + text_before_cursor: str, + word_before_cursor: str | None, + full_text: str, + identifier: Identifier, +) -> list[dict[str, Any]]: + ctx = _build_suggest_context(token, text_before_cursor, word_before_cursor, full_text, identifier) + for rule in SUGGEST_BASED_ON_LAST_TOKEN_RULES: + if rule.predicate(ctx): + return rule.emit(ctx) + + return _keyword_suggestions() + + +def identifies( + identifier: Any, + schema: str | None, + table: str, + alias: str, +) -> bool: + if identifier == alias: + return True + if identifier == table: + return True + if schema and identifier == (schema + "." + table): + return True + return False diff --git a/mycli/packages/filepaths.py b/mycli/packages/filepaths.py index a91055d2..5d67582c 100644 --- a/mycli/packages/filepaths.py +++ b/mycli/packages/filepaths.py @@ -1,31 +1,37 @@ import os import platform - +DEFAULT_SOCKET_DIRS: list[str] = [] if os.name == "posix": if platform.system() == "Darwin": - DEFAULT_SOCKET_DIRS = ("/tmp",) + DEFAULT_SOCKET_DIRS = ["/tmp"] else: - DEFAULT_SOCKET_DIRS = ("/var/run", "/var/lib") -else: - DEFAULT_SOCKET_DIRS = () + DEFAULT_SOCKET_DIRS = ["/var/run", "/var/lib"] -def list_path(root_dir): +def list_path(root_dir: str) -> list[str]: """List directory if exists. :param root_dir: str :return: list """ - res = [] - if os.path.isdir(root_dir): - for name in os.listdir(root_dir): - res.append(name) - return res - - -def complete_path(curr_dir, last_dir): + files = [] + dirs = [] + if not os.path.isdir(root_dir): + return [] + for name in sorted(os.listdir(root_dir)): + if name.startswith('.'): + continue + elif os.path.isdir(name): + dirs.append(f'{name}/') + # if .sql is too restrictive it can be made configurable with some effort + elif name.lower().endswith('.sql'): + files.append(name) + return files + dirs + + +def complete_path(curr_dir: str, last_dir: str) -> str: """Return the path to complete that matches the last entered component. If the last entered component is ~, expanded path would not @@ -38,11 +44,13 @@ def complete_path(curr_dir, last_dir): """ if not last_dir or curr_dir.startswith(last_dir): return curr_dir - elif last_dir == '~': + elif last_dir == "~": return os.path.join(last_dir, curr_dir) + else: + return '' -def parse_path(root_dir): +def parse_path(root_dir: str) -> tuple[str, str, int]: """Split path into head and last component for the completer. Also return position where last component starts. @@ -51,14 +59,14 @@ def parse_path(root_dir): :return: tuple of (string, string, int) """ - base_dir, last_dir, position = '', '', 0 + base_dir, last_dir, position = "", "", 0 if root_dir: base_dir, last_dir = os.path.split(root_dir) position = -len(last_dir) if last_dir else 0 return base_dir, last_dir, position -def suggest_path(root_dir): +def suggest_path(root_dir: str) -> list[str]: """List all files and subdirectories in a directory. If the directory is not specified, suggest root directory, @@ -69,9 +77,18 @@ def suggest_path(root_dir): """ if not root_dir: - return [os.path.abspath(os.sep), '~', os.curdir, os.pardir] - - if '~' in root_dir: + return [ + os.path.abspath(os.sep), + "~", + os.curdir, + os.pardir, + *list_path(os.curdir), + ] + + if root_dir[0] not in ('/', '~') and root_dir[0:1] != './': + return list_path(os.curdir) + + if "~" in root_dir: root_dir = os.path.expanduser(root_dir) if not os.path.exists(root_dir): @@ -80,7 +97,7 @@ def suggest_path(root_dir): return list_path(root_dir) -def dir_path_exists(path): +def dir_path_exists(path: str) -> bool: """Check if the directory path exists for a given file. For example, for a file /home/user/.cache/mycli/log, check if @@ -93,14 +110,14 @@ def dir_path_exists(path): return os.path.exists(os.path.dirname(path)) -def guess_socket_location(): +def guess_socket_location() -> str | None: """Try to guess the location of the default mysql socket file.""" socket_dirs = filter(os.path.exists, DEFAULT_SOCKET_DIRS) for directory in socket_dirs: for r, dirs, files in os.walk(directory, topdown=True): for filename in files: name, ext = os.path.splitext(filename) - if name.startswith("mysql") and name != "mysqlx" and ext in ('.socket', '.sock'): + if name.startswith("mysql") and name != "mysqlx" and ext in (".socket", ".sock"): return os.path.join(r, filename) dirs[:] = [d for d in dirs if d.startswith("mysql")] return None diff --git a/mycli/packages/hybrid_redirection.py b/mycli/packages/hybrid_redirection.py new file mode 100644 index 00000000..9312eea9 --- /dev/null +++ b/mycli/packages/hybrid_redirection.py @@ -0,0 +1,204 @@ +import functools +import logging + +import sqlglot + +from mycli.compat import WIN +from mycli.packages.special.delimitercommand import DelimiterCommand + +logger = logging.getLogger(__name__) +delimiter_command = DelimiterCommand() + + +def find_token_indices(tokens: list[sqlglot.Token]) -> dict[str, list[int]]: + token_indices: dict[str, list[int]] = { + 'raw_dollar': [], + 'true_dollar': [], + 'angle_bracket': [], + 'pipe': [], + } + + for i, tok in enumerate(tokens): + if tok.token_type == sqlglot.TokenType.VAR and tok.text == '$': + token_indices['raw_dollar'].append(i) + continue + if tok.token_type == sqlglot.TokenType.GT and (i - 1) in token_indices['raw_dollar']: + token_indices['angle_bracket'].append(i) + continue + if tok.token_type == sqlglot.TokenType.PIPE and (i - 1) in token_indices['raw_dollar']: + token_indices['pipe'].append(i) + continue + + for i in token_indices['raw_dollar']: + if (i + 1) in token_indices['angle_bracket'] or (i + 1) in token_indices['pipe']: + token_indices['true_dollar'].append(i) + + return token_indices + + +def find_sql_part( + command: str, + tokens: list[sqlglot.Token], + true_dollar_indices: list[int], +): + leftmost_dollar_pos = tokens[true_dollar_indices[0]].start + sql_part = command[0:leftmost_dollar_pos].strip().removesuffix(delimiter_command.current).rstrip() + try: + statements = sqlglot.parse(sql_part, read='mysql') + except sqlglot.errors.ParseError: + return '' + if len(statements) != 1: + # buglet: the statement count doesn't respect a custom delimiter + return '' + return sql_part + + +def find_command_tokens( + tokens: list[sqlglot.Token], + true_dollar_indices: list[int], +) -> list[sqlglot.Token]: + command_part_tokens = [] + + for i, tok in enumerate(tokens): + if i < true_dollar_indices[0]: + continue + if i in true_dollar_indices: + continue + command_part_tokens.append(tok) + + if command_part_tokens: + _operator = command_part_tokens.pop(0) + + return command_part_tokens + + +def find_file_tokens( + tokens: list[sqlglot.Token], + angle_bracket_indices: list[int], +) -> tuple[list[sqlglot.Token], int, str | None]: + file_part_tokens: list[sqlglot.Token] = [] + file_part_index = len(tokens) + + if not angle_bracket_indices: + return file_part_tokens, file_part_index, None + + file_part_tokens = tokens[angle_bracket_indices[-1] :] + file_part_index = angle_bracket_indices[-1] + + file_operator_part = file_part_tokens.pop(0).text + if file_operator_part == '>' and file_part_tokens[0].token_type == sqlglot.TokenType.GT: + file_part_tokens.pop(0) + file_operator_part = '>>' + + return file_part_tokens, file_part_index, file_operator_part + + +def assemble_tokens(tokens: list[sqlglot.Token]) -> str: + assembled_string = ' ' * (tokens[-1].end + 10) + for tok in tokens: + if tok.token_type == sqlglot.TokenType.IDENTIFIER: + text = f'"{tok.text}"' + offset = 2 + elif tok.token_type == sqlglot.TokenType.STRING: + text = f"'{tok.text}'" + offset = 2 + else: + text = tok.text + offset = 0 + assembled_string = assembled_string[0 : tok.start] + text + assembled_string[tok.end + offset :] + return assembled_string.strip().removesuffix(delimiter_command.current).rstrip() + + +def invalid_shell_part( + file_part: str | None, + command_part: str | None, +) -> bool: + if file_part and ' ' in file_part: + return True + + if file_part and '>' in file_part: + return True + + if not file_part and not command_part: + return True + + return False + + +# todo there are still corner cases combining custom delimiters, caching, and redirection +@functools.lru_cache(maxsize=1) +def get_redirect_components(command: str) -> tuple[str | None, str | None, str | None, str | None]: + """Get the parts of a hybrid shell-style redirect command.""" + + try: + tokens = sqlglot.tokenize(command) + except sqlglot.errors.TokenError: + return None, None, None, None + + token_indices = find_token_indices(tokens) + + if not token_indices['true_dollar']: + return None, None, None, None + + if len(token_indices['angle_bracket']) > 1: + return None, None, None, None + + if WIN and len(token_indices['pipe']) > 1: + # how to give better feedback here? + return None, None, None, None + + if token_indices['angle_bracket'] and token_indices['pipe']: + if token_indices['pipe'][-1] > token_indices['angle_bracket'][-1]: + return None, None, None, None + + sql_part = find_sql_part( + command, + tokens, + token_indices['true_dollar'], + ) + if not sql_part: + return None, None, None, None + + ( + file_part_tokens, + file_part_index, + file_operator_part, + ) = find_file_tokens( + tokens, + token_indices['angle_bracket'], + ) + + command_part_tokens = find_command_tokens( + tokens[0:file_part_index], + token_indices['true_dollar'], + ) + + if file_part_tokens: + file_part = assemble_tokens(file_part_tokens) + else: + file_part = None + + if command_part_tokens: + command_part = assemble_tokens(command_part_tokens) + else: + command_part = None + + if invalid_shell_part(file_part, command_part): + return None, None, None, None + + logger.debug('redirect parse sql_part: "{}"'.format(sql_part)) + logger.debug('redirect parse command_part: "{}"'.format(command_part)) + logger.debug('redirect parse file_operator_part: "{}"'.format(file_operator_part)) + logger.debug('redirect parse file_part: "{}"'.format(file_part)) + + return sql_part, command_part, file_operator_part, file_part + + +def is_redirect_command(command: str) -> bool: + """Is this a shell-style redirect to command or file? + + :param command: string + + """ + sql_part, _command_part, _file_operator_part, _file_part = get_redirect_components(command) + return bool(sql_part) diff --git a/mycli/packages/prompt_utils.py b/mycli/packages/interactive_utils.py similarity index 59% rename from mycli/packages/prompt_utils.py rename to mycli/packages/interactive_utils.py index fb1e431a..fa0f0537 100644 --- a/mycli/packages/prompt_utils.py +++ b/mycli/packages/interactive_utils.py @@ -1,29 +1,31 @@ import sys + import click -from .parseutils import is_destructive + +from mycli.packages.sql_utils import is_destructive class ConfirmBoolParamType(click.ParamType): - name = 'confirmation' + name = "confirmation" - def convert(self, value, param, ctx): + def convert(self, value: bool | str, param: click.Parameter | None, ctx: click.Context | None) -> bool: if isinstance(value, bool): return bool(value) value = value.lower() - if value in ('yes', 'y'): + if value in ("yes", "y"): return True - elif value in ('no', 'n'): + if value in ("no", "n"): return False - self.fail('%s is not a valid boolean' % value, param, ctx) + self.fail(f'{value} is not a valid boolean', param, ctx) def __repr__(self): - return 'BOOL' + return "BOOL" BOOLEAN_TYPE = ConfirmBoolParamType() -def confirm_destructive_query(queries): +def confirm_destructive_query(keywords: list[str], queries: str) -> bool | None: """Check if the query is destructive and prompts the user to confirm. Returns: @@ -32,13 +34,14 @@ def confirm_destructive_query(queries): * False if the query is destructive and the user doesn't want to proceed. """ - prompt_text = ("You're about to run a destructive command.\n" - "Do you want to proceed? (y/n)") - if is_destructive(queries) and sys.stdin.isatty(): + prompt_text = "You're about to run a destructive command.\nDo you want to proceed? (y/n)" + if is_destructive(keywords, queries) and sys.stdin.isatty(): return prompt(prompt_text, type=BOOLEAN_TYPE) + else: + return None -def confirm(*args, **kwargs): +def confirm(*args, **kwargs) -> bool: """Prompt for confirmation (yes/no) and handle any abort exceptions.""" try: return click.confirm(*args, **kwargs) @@ -46,7 +49,7 @@ def confirm(*args, **kwargs): return False -def prompt(*args, **kwargs): +def prompt(*args, **kwargs) -> bool: """Prompt the user for input and handle any abort exceptions.""" try: return click.prompt(*args, **kwargs) diff --git a/mycli/packages/key_binding_utils.py b/mycli/packages/key_binding_utils.py new file mode 100644 index 00000000..cdf8af6a --- /dev/null +++ b/mycli/packages/key_binding_utils.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from prompt_toolkit.shortcuts import PromptSession +import sqlglot + +from mycli.packages import special +from mycli.sqlexecute import SQLExecute + +if TYPE_CHECKING: + from mycli.main import MyCli + + +def server_date(sqlexecute: SQLExecute, quoted: bool = False) -> str: + server_date_str = sqlexecute.now().strftime('%Y-%m-%d') + if quoted: + return f"'{server_date_str}'" + else: + return server_date_str + + +def server_datetime(sqlexecute: SQLExecute, quoted: bool = False) -> str: + server_datetime_str = sqlexecute.now().strftime('%Y-%m-%d %H:%M:%S') + if quoted: + return f"'{server_datetime_str}'" + else: + return server_datetime_str + + +# todo: maybe these handlers belong in a repl_handlers.py (which does not exist yet) +# \clip doesn't even have a keybinding +def handle_clip_command(mycli: 'MyCli', text: str) -> bool: + r"""A clip command is any query that is prefixed or suffixed by a + '\clip'. + + :param text: Document + :return: Boolean + + """ + + if special.clip_command(text): + query = special.get_clip_query(text) or mycli.get_last_query() + message = special.copy_query_to_clipboard(sql=query) + if message: + raise RuntimeError(message) + return True + return False + + +def handle_editor_command( + mycli: 'MyCli', + text: str, + inputhook: Callable | None, + loaded_message_fn: Callable, +) -> str: + r"""Editor command is any query that is prefixed or suffixed by a '\e'. + The reason for a while loop is because a user might edit a query + multiple times. For eg: + + "select * from \e" to edit it in vim, then come + back to the prompt with the edited query "select * from + blah where q = 'abc'\e" to edit it again. + :param text: Document + :return: Document + + """ + + while special.editor_command(text): + filename = special.get_filename(text) + query = special.get_editor_query(text) or mycli.get_last_query() + sql, message = special.open_external_editor(filename=filename, sql=query) + if message: + # Something went wrong. Raise an exception and bail. + raise RuntimeError(message) + while True: + try: + assert isinstance(mycli.prompt_session, PromptSession) + text = mycli.prompt_session.prompt( + default=sql, + inputhook=inputhook, + message=loaded_message_fn, + ) + break + except KeyboardInterrupt: + sql = "" + + continue + return text + + +def handle_prettify_binding( + mycli: 'MyCli', + text: str, +) -> str: + if not text: + return '' + try: + statements = sqlglot.parse(text, read='mysql') + except Exception: + statements = [] + if len(statements) == 1 and statements[0]: + parse_succeeded = True + pretty_text = statements[0].sql(pretty=True, pad=4, dialect='mysql') + else: + parse_succeeded = False + pretty_text = text.rstrip(';') + mycli.toolbar_error_message = 'Prettify failed to parse single statement' + if pretty_text and parse_succeeded: + pretty_text = pretty_text + ';' + return pretty_text + + +def handle_unprettify_binding( + mycli: 'MyCli', + text: str, +) -> str: + if not text: + return '' + try: + statements = sqlglot.parse(text, read='mysql') + except Exception: + statements = [] + if len(statements) == 1 and statements[0]: + parse_succeeded = True + unpretty_text = statements[0].sql(pretty=False, dialect='mysql') + else: + parse_succeeded = False + unpretty_text = text.rstrip(';') + mycli.toolbar_error_message = 'Unprettify failed to parse single statement' + if unpretty_text and parse_succeeded: + unpretty_text = unpretty_text + ';' + return unpretty_text diff --git a/mycli/packages/paramiko_stub/__init__.py b/mycli/packages/paramiko_stub/__init__.py index 045b00ea..da2eca04 100644 --- a/mycli/packages/paramiko_stub/__init__.py +++ b/mycli/packages/paramiko_stub/__init__.py @@ -1,27 +1,35 @@ """A module to import instead of paramiko when it is not available (to avoid checking for paramiko all over the place). -When paramiko is first envoked, it simply shuts down mycli, telling -user they either have to install paramiko or should not use SSH -features. +When paramiko is first invoked, this simply shuts down mycli, telling the +user they either have to install paramiko or should not use SSH features. """ class Paramiko: - def __getattr__(self, name): + def __getattr__(self, name: str) -> None: import sys from textwrap import dedent - print(dedent(""" - To enable certain SSH features you need to install paramiko: - - pip install paramiko - - It is required for the following configuration options: + + print( + dedent(""" + To enable certain SSH features you need to install ssh extras: + + pip install 'mycli[ssh]' + + or + + pip install paramiko sshtunnel + + This is required for the following command-line arguments: + --list-ssh-config --ssh-config-host --ssh-host - """)) + """), + file=sys.stderr, + ) sys.exit(1) diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py deleted file mode 100644 index 3090530d..00000000 --- a/mycli/packages/parseutils.py +++ /dev/null @@ -1,266 +0,0 @@ -import re -import sqlparse -from sqlparse.sql import IdentifierList, Identifier, Function -from sqlparse.tokens import Keyword, DML, Punctuation - -cleanup_regex = { - # This matches only alphanumerics and underscores. - 'alphanum_underscore': re.compile(r'(\w+)$'), - # This matches everything except spaces, parens, colon, and comma - 'many_punctuations': re.compile(r'([^():,\s]+)$'), - # This matches everything except spaces, parens, colon, comma, and period - 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), - # This matches everything except a space. - 'all_punctuations': re.compile(r'([^\s]+)$'), -} - - -def last_word(text, include='alphanum_underscore'): - r""" - Find the last word in a sentence. - - >>> last_word('abc') - 'abc' - >>> last_word(' abc') - 'abc' - >>> last_word('') - '' - >>> last_word(' ') - '' - >>> last_word('abc ') - '' - >>> last_word('abc def') - 'def' - >>> last_word('abc def ') - '' - >>> last_word('abc def;') - '' - >>> last_word('bac $def') - 'def' - >>> last_word('bac $def', include='most_punctuations') - '$def' - >>> last_word('bac \def', include='most_punctuations') - '\\\\def' - >>> last_word('bac \def;', include='most_punctuations') - '\\\\def;' - >>> last_word('bac::def', include='most_punctuations') - 'def' - """ - - if not text: # Empty string - return '' - - if text[-1].isspace(): - return '' - else: - regex = cleanup_regex[include] - matches = regex.search(text) - if matches: - return matches.group(0) - else: - return '' - - -# This code is borrowed from sqlparse example script. -# -def is_subselect(parsed): - if not parsed.is_group: - return False - for item in parsed.tokens: - if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', - 'UPDATE', 'CREATE', 'DELETE'): - return True - return False - -def extract_from_part(parsed, stop_at_punctuation=True): - tbl_prefix_seen = False - for item in parsed.tokens: - if tbl_prefix_seen: - if is_subselect(item): - for x in extract_from_part(item, stop_at_punctuation): - yield x - elif stop_at_punctuation and item.ttype is Punctuation: - return - # Multiple JOINs in the same query won't work properly since - # "ON" is a keyword and will trigger the next elif condition. - # So instead of stooping the loop when finding an "ON" skip it - # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi' - elif item.ttype is Keyword and item.value.upper() == 'ON': - tbl_prefix_seen = False - continue - # An incomplete nested select won't be recognized correctly as a - # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes - # the second FROM to trigger this elif condition resulting in a - # StopIteration. So we need to ignore the keyword if the keyword - # FROM. - # Also 'SELECT * FROM abc JOIN def' will trigger this elif - # condition. So we need to ignore the keyword JOIN and its variants - # INNER JOIN, FULL OUTER JOIN, etc. - elif item.ttype is Keyword and ( - not item.value.upper() == 'FROM') and ( - not item.value.upper().endswith('JOIN')): - return - else: - yield item - elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and - item.value.upper() in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)): - tbl_prefix_seen = True - # 'SELECT a, FROM abc' will detect FROM as part of the column list. - # So this check here is necessary. - elif isinstance(item, IdentifierList): - for identifier in item.get_identifiers(): - if (identifier.ttype is Keyword and - identifier.value.upper() == 'FROM'): - tbl_prefix_seen = True - break - -def extract_table_identifiers(token_stream): - """yields tuples of (schema_name, table_name, table_alias)""" - - for item in token_stream: - if isinstance(item, IdentifierList): - for identifier in item.get_identifiers(): - # Sometimes Keywords (such as FROM ) are classified as - # identifiers which don't have the get_real_name() method. - try: - schema_name = identifier.get_parent_name() - real_name = identifier.get_real_name() - except AttributeError: - continue - if real_name: - yield (schema_name, real_name, identifier.get_alias()) - elif isinstance(item, Identifier): - real_name = item.get_real_name() - schema_name = item.get_parent_name() - - if real_name: - yield (schema_name, real_name, item.get_alias()) - else: - name = item.get_name() - yield (None, name, item.get_alias() or name) - elif isinstance(item, Function): - yield (None, item.get_name(), item.get_name()) - -# extract_tables is inspired from examples in the sqlparse lib. -def extract_tables(sql): - """Extract the table names from an SQL statement. - - Returns a list of (schema, table, alias) tuples - - """ - parsed = sqlparse.parse(sql) - if not parsed: - return [] - - # INSERT statements must stop looking for tables at the sign of first - # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) - # abc is the table name, but if we don't stop at the first lparen, then - # we'll identify abc, col1 and col2 as table names. - insert_stmt = parsed[0].token_first().value.lower() == 'insert' - stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) - return list(extract_table_identifiers(stream)) - -def find_prev_keyword(sql): - """ Find the last sql keyword in an SQL statement - - Returns the value of the last keyword, and the text of the query with - everything after the last keyword stripped - """ - if not sql.strip(): - return None, '' - - parsed = sqlparse.parse(sql)[0] - flattened = list(parsed.flatten()) - - logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') - - for t in reversed(flattened): - if t.value == '(' or (t.is_keyword and ( - t.value.upper() not in logical_operators)): - # Find the location of token t in the original parsed statement - # We can't use parsed.token_index(t) because t may be a child token - # inside a TokenList, in which case token_index thows an error - # Minimal example: - # p = sqlparse.parse('select * from foo where bar') - # t = list(p.flatten())[-3] # The "Where" token - # p.token_index(t) # Throws ValueError: not in list - idx = flattened.index(t) - - # Combine the string values of all tokens in the original list - # up to and including the target keyword token t, to produce a - # query string with everything after the keyword token removed - text = ''.join(tok.value for tok in flattened[:idx+1]) - return t, text - - return None, '' - - -def query_starts_with(query, prefixes): - """Check if the query starts with any item from *prefixes*.""" - prefixes = [prefix.lower() for prefix in prefixes] - formatted_sql = sqlparse.format(query.lower(), strip_comments=True) - return bool(formatted_sql) and formatted_sql.split()[0] in prefixes - - -def queries_start_with(queries, prefixes): - """Check if any queries start with any item from *prefixes*.""" - for query in sqlparse.split(queries): - if query and query_starts_with(query, prefixes) is True: - return True - return False - - -def query_has_where_clause(query): - """Check if the query contains a where-clause.""" - return any( - isinstance(token, sqlparse.sql.Where) - for token_list in sqlparse.parse(query) - for token in token_list - ) - - -def is_destructive(queries): - """Returns if any of the queries in *queries* is destructive.""" - keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter') - for query in sqlparse.split(queries): - if query: - if query_starts_with(query, keywords) is True: - return True - elif query_starts_with( - query, ['update'] - ) is True and not query_has_where_clause(query): - return True - - return False - - -if __name__ == '__main__': - sql = 'select * from (select t. from tabl t' - print (extract_tables(sql)) - - -def is_dropping_database(queries, dbname): - """Determine if the query is dropping a specific database.""" - result = False - if dbname is None: - return False - - def normalize_db_name(db): - return db.lower().strip('`"') - - dbname = normalize_db_name(dbname) - - for query in sqlparse.parse(queries): - keywords = [t for t in query.tokens if t.is_keyword] - if len(keywords) < 2: - continue - if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( - "database", - "schema", - ): - database_token = next( - (t for t in query.tokens if isinstance(t, Identifier)), None - ) - if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: - result = keywords[0].normalized == "DROP" - return result diff --git a/mycli/packages/toolkit/__init__.py b/mycli/packages/ptoolkit/__init__.py similarity index 100% rename from mycli/packages/toolkit/__init__.py rename to mycli/packages/ptoolkit/__init__.py diff --git a/mycli/packages/ptoolkit/fzf.py b/mycli/packages/ptoolkit/fzf.py new file mode 100644 index 00000000..f455edd3 --- /dev/null +++ b/mycli/packages/ptoolkit/fzf.py @@ -0,0 +1,76 @@ +import re +import shlex +from shutil import which + +from prompt_toolkit import search +from prompt_toolkit.key_binding.key_processor import KeyPressEvent +from pyfzf import FzfPrompt + +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp +from mycli.packages.ptoolkit.utils import safe_invalidate_display + + +class Fzf(FzfPrompt): + def __init__(self): + self.executable = which("fzf") + if self.executable: + super().__init__() + + def is_available(self) -> bool: + return self.executable is not None + + +def search_history( + event: KeyPressEvent, + highlight_preview: bool = False, + highlight_style: str = 'default', + incremental: bool = False, +) -> None: + buffer = event.current_buffer + history = buffer.history + + fzf = Fzf() + + if incremental or not fzf.is_available() or not isinstance(history, FileHistoryWithTimestamp): + # Fallback to default reverse incremental search + search.start_search(direction=search.SearchDirection.BACKWARD) + return + + history_items_with_timestamp = history.load_history_with_timestamp() + + formatted_history_items = [] + original_history_items = [] + seen = {} + for item, timestamp in history_items_with_timestamp: + formatted_item = re.sub(r'\s+', ' ', item) + timestamp = timestamp.split(".")[0] if "." in timestamp else timestamp + if formatted_item in seen: + continue + seen[formatted_item] = True + formatted_history_items.append(f"{timestamp} {formatted_item}") + original_history_items.append(item) + + options = [ + '--info=hidden', + '--scheme=history', + '--tiebreak=index', + '--bind=ctrl-r:up,alt-r:up', + '--preview-window=down:wrap:nohidden', + '--no-height', + ] + + if highlight_preview and which('pygmentize'): + options.append(f'--preview="printf \'%s\' {{}} | pygmentize -l mysql -P style={shlex.quote(highlight_style)}"') + else: + options.append('--preview="printf \'%s\' {}"') + + result = fzf.prompt( + formatted_history_items, + fzf_options=' '.join(options), + ) + safe_invalidate_display(event.app) + + if result: + selected_index = formatted_history_items.index(result[0]) + buffer.text = original_history_items[selected_index] + buffer.cursor_position = len(buffer.text) diff --git a/mycli/packages/toolkit/history.py b/mycli/packages/ptoolkit/history.py similarity index 67% rename from mycli/packages/toolkit/history.py rename to mycli/packages/ptoolkit/history.py index 75f4a5a2..982bc774 100644 --- a/mycli/packages/toolkit/history.py +++ b/mycli/packages/ptoolkit/history.py @@ -1,8 +1,10 @@ import os -from typing import Iterable, Union, List, Tuple +from typing import Union from prompt_toolkit.history import FileHistory +from mycli.packages.sql_utils import is_password_change + _StrOrBytesPath = Union[str, bytes, os.PathLike] @@ -15,16 +17,23 @@ def __init__(self, filename: _StrOrBytesPath) -> None: self.filename = filename super().__init__(filename) - def load_history_with_timestamp(self) -> List[Tuple[str, str]]: + def append_string(self, string: str) -> None: + "Add string to the history." + self._loaded_strings.insert(0, string) + if is_password_change(string): + return + self.store_string(string) + + def load_history_with_timestamp(self) -> list[tuple[str, str]]: """ Load history entries along with their timestamps. Returns: - List[Tuple[str, str]]: A list of tuples where each tuple contains + list[tuple[str, str]]: A list of tuples where each tuple contains a history entry and its corresponding timestamp. """ - history_with_timestamp: List[Tuple[str, str]] = [] - lines: List[str] = [] + history_with_timestamp: list[tuple[str, str]] = [] + lines: list[str] = [] timestamp: str = "" def add() -> None: @@ -34,10 +43,8 @@ def add() -> None: history_with_timestamp.append((string, timestamp)) if os.path.exists(self.filename): - with open(self.filename, "rb") as f: - for line_bytes in f: - line = line_bytes.decode("utf-8", errors="replace") - + with open(self.filename, 'r', encoding='utf-8') as f: + for line in f: if line.startswith("#"): # Extract timestamp timestamp = line[2:].strip() diff --git a/mycli/packages/ptoolkit/utils.py b/mycli/packages/ptoolkit/utils.py new file mode 100644 index 00000000..1a38bb4f --- /dev/null +++ b/mycli/packages/ptoolkit/utils.py @@ -0,0 +1,23 @@ +from prompt_toolkit.application import Application, run_in_terminal + + +def safe_invalidate_display(app: Application) -> None: + """ + fzf can confuse the terminal/app when certain values are set in + environment variable FZF_DEFAULT_OPTS. + + The same could happen after running other external programs. + + This function invalidates the prompt_toolkit display, causing a + refresh of the prompt message and pending user input, without + leading to exceptions at exit time, as the built-in + app.invalidate() does. + """ + + def print_empty_string(): + app.print_text('') + + try: + run_in_terminal(print_empty_string) + except RuntimeError: + pass diff --git a/mycli/packages/special/__init__.py b/mycli/packages/special/__init__.py index 92bcca6d..9b226b84 100644 --- a/mycli/packages/special/__init__.py +++ b/mycli/packages/special/__init__.py @@ -1,10 +1,109 @@ -__all__ = [] +from mycli.packages.special.dbcommands import ( + list_databases, + list_tables, + status, +) +from mycli.packages.special.iocommands import ( + clip_command, + close_tee, + copy_query_to_clipboard, + disable_pager, + disable_show_warnings, + editor_command, + enable_show_warnings, + flush_pipe_once_if_written, + forced_horizontal, + get_clip_query, + get_current_delimiter, + get_editor_query, + get_filename, + is_expanded_output, + is_pager_enabled, + is_redirected, + is_show_favorite_query, + is_show_warnings_enabled, + is_timing_enabled, + open_external_editor, + set_delimiter, + set_destructive_keywords, + set_expanded_output, + set_favorite_queries, + set_forced_horizontal_output, + set_pager, + set_pager_enabled, + set_redirect, + set_show_favorite_query, + set_show_warnings_enabled, + set_timing_enabled, + split_queries, + unset_once_if_written, + write_once, + write_pipe_once, + write_tee, +) +from mycli.packages.special.llm import ( + FinishIteration, + handle_llm, + is_llm_command, + sql_using_llm, +) +from mycli.packages.special.main import ( + CommandNotFound, + SpecialCommandAlias, + execute, + parse_special_command, + register_special_command, + special_command, +) -def export(defn): - """Decorator to explicitly mark functions that are exposed in a lib.""" - globals()[defn.__name__] = defn - __all__.append(defn.__name__) - return defn - -from . import dbcommands -from . import iocommands +__all__: list[str] = [ + 'CommandNotFound', + 'FinishIteration', + 'SpecialCommandAlias', + 'clip_command', + 'close_tee', + 'copy_query_to_clipboard', + 'disable_pager', + 'disable_show_warnings', + 'editor_command', + 'enable_show_warnings', + 'execute', + 'flush_pipe_once_if_written', + 'forced_horizontal', + 'get_clip_query', + 'get_current_delimiter', + 'get_editor_query', + 'get_filename', + 'handle_llm', + 'is_expanded_output', + 'is_llm_command', + 'is_pager_enabled', + 'is_redirected', + 'is_show_warnings_enabled', + 'is_timing_enabled', + 'list_databases', + 'list_tables', + 'open_external_editor', + 'parse_special_command', + 'register_special_command', + 'set_delimiter', + 'set_destructive_keywords', + 'set_expanded_output', + 'set_favorite_queries', + 'set_forced_horizontal_output', + 'set_pager', + 'set_pager_enabled', + 'set_redirect', + 'set_show_warnings_enabled', + 'set_timing_enabled', + 'set_show_favorite_query', + 'is_show_favorite_query', + 'special_command', + 'split_queries', + 'sql_using_llm', + 'status', + 'unset_once_if_written', + 'write_once', + 'write_pipe_once', + 'write_tee', +] diff --git a/mycli/packages/special/dbcommands.py b/mycli/packages/special/dbcommands.py index 5c29c555..0965efd3 100644 --- a/mycli/packages/special/dbcommands.py +++ b/mycli/packages/special/dbcommands.py @@ -1,162 +1,220 @@ import logging import os import platform -from mycli import __version__ -from mycli.packages.special import iocommands -from mycli.packages.special.utils import format_uptime -from .main import special_command, RAW_QUERY, PARSED_QUERY -from pymysql import ProgrammingError - -log = logging.getLogger(__name__) +from pymysql import ProgrammingError +from pymysql.cursors import Cursor -@special_command('\\dt', '\\dt[+] [table]', 'List or describe tables.', - arg_type=PARSED_QUERY, case_sensitive=True) -def list_tables(cur, arg=None, arg_type=PARSED_QUERY, verbose=False): +from mycli import __version__ +from mycli.packages.special import iocommands +from mycli.packages.special.main import ArgType, SpecialCommandAlias, special_command +from mycli.packages.special.utils import ( + format_uptime, + get_local_timezone, + get_server_timezone, + get_ssl_cipher, + get_ssl_version, +) +from mycli.packages.sqlresult import SQLResult + +logger = logging.getLogger(__name__) + + +@special_command( + "\\dt", + "\\dt[+] [table]", + "List or describe tables.", + arg_type=ArgType.PARSED_QUERY, + case_sensitive=True, +) +def list_tables( + cur: Cursor, + arg: str | None = None, + _arg_type: ArgType = ArgType.PARSED_QUERY, + command_verbosity: bool = False, +) -> list[SQLResult]: if arg: - query = 'SHOW FIELDS FROM {0}'.format(arg) + query = f'SHOW FIELDS FROM {arg}' else: - query = 'SHOW TABLES' - log.debug(query) + query = "SHOW TABLES" + logger.debug(query) cur.execute(query) - tables = cur.fetchall() - status = '' if cur.description: - headers = [x[0] for x in cur.description] + header = [x[0] for x in cur.description] else: - return [(None, None, None, '')] - - if verbose and arg: - query = 'SHOW CREATE TABLE {0}'.format(arg) - log.debug(query) - cur.execute(query) - status = cur.fetchone()[1] - - return [(None, tables, headers, status)] + return [SQLResult()] + # Fetch results before potentially executing another query + results = list(cur.fetchall()) if command_verbosity and arg else cur -@special_command('\\l', '\\l', 'List databases.', arg_type=RAW_QUERY, case_sensitive=True) -def list_databases(cur, **_): - query = 'SHOW DATABASES' - log.debug(query) + postamble = '' + if command_verbosity and arg: + query = f'SHOW CREATE TABLE {arg}' + logger.debug(query) + cur.execute(query) + if one := cur.fetchone(): + postamble = one[1] + + # todo missing a status line because sqlexecute.get_result was not used + return [SQLResult(header=header, rows=results, postamble=postamble)] + + +@special_command( + "\\l", + "\\l", + "List databases.", + arg_type=ArgType.RAW_QUERY, + case_sensitive=True, +) +def list_databases(cur: Cursor, **_) -> list[SQLResult]: + query = "SHOW DATABASES" + logger.debug(query) cur.execute(query) if cur.description: - headers = [x[0] for x in cur.description] - return [(None, cur, headers, '')] + header = [x[0] for x in cur.description] + # todo missing a status line because sqlexecute.get_result was not used + return [SQLResult(header=header, rows=cur)] else: - return [(None, None, None, '')] - - -@special_command('status', '\\s', 'Get status information from the server.', - arg_type=RAW_QUERY, aliases=('\\s', ), case_sensitive=True) -def status(cur, **_): - query = 'SHOW GLOBAL STATUS;' - log.debug(query) + return [SQLResult()] + + +@special_command( + "status", + "status", + "Get status information from the server.", + arg_type=ArgType.RAW_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\s", case_sensitive=True)], +) +def status(cur: Cursor, **_) -> list[SQLResult]: + query = "SHOW GLOBAL STATUS;" + logger.debug(query) try: cur.execute(query) except ProgrammingError: - # Fallback in case query fail, as it does with Mysql 4 - query = 'SHOW STATUS;' - log.debug(query) + # Fallback in case query fails, as it does with Mysql 4 + query = "SHOW STATUS;" + logger.debug(query) cur.execute(query) status = dict(cur.fetchall()) - query = 'SHOW GLOBAL VARIABLES;' - log.debug(query) + query = "SHOW GLOBAL VARIABLES;" + logger.debug(query) cur.execute(query) - variables = dict(cur.fetchall()) + global_variables = dict(cur.fetchall()) - # prepare in case keys are bytes, as with Python 3 and Mysql 4 - if (isinstance(list(variables)[0], bytes) and - isinstance(list(status)[0], bytes)): - variables = {k.decode('utf-8'): v.decode('utf-8') for k, v - in variables.items()} - status = {k.decode('utf-8'): v.decode('utf-8') for k, v - in status.items()} + query = "SHOW SESSION VARIABLES;" + logger.debug(query) + cur.execute(query) + session_variables = dict(cur.fetchall()) + + # decode in case keys are bytes, as with Mysql 4 + if global_variables and isinstance(list(global_variables)[0], bytes): + global_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in global_variables.items()} + if session_variables and isinstance(list(session_variables)[0], bytes): + session_variables = {k.decode("utf-8"): v.decode("utf-8") for k, v in session_variables.items()} + if status and isinstance(list(status)[0], bytes): + status = {k.decode("utf-8"): v.decode("utf-8") for k, v in status.items()} # Create output buffers. - title = [] + preamble = [] + header = ['Setting', 'Value'] output = [] footer = [] - title.append('--------------') + preamble.append("--------------") # Output the mycli client information. implementation = platform.python_implementation() version = platform.python_version() client_info = [] - client_info.append('mycli {0},'.format(__version__)) - client_info.append('running on {0} {1}'.format(implementation, version)) - title.append(' '.join(client_info) + '\n') + client_info.append(f'mycli {__version__}') + client_info.append(f'running on {implementation} {version}') + preamble.append(" ".join(client_info) + "\n") # Build the output that will be displayed as a table. - output.append(('Connection id:', cur.connection.thread_id())) + output.append(("Connection id:", cur.connection.thread_id())) - query = 'SELECT DATABASE(), USER();' - log.debug(query) + query = "SELECT DATABASE(), USER();" + logger.debug(query) cur.execute(query) - db, user = cur.fetchone() - if db is None: - db = '' - - output.append(('Current database:', db)) - output.append(('Current user:', user)) + if one := cur.fetchone(): + db, user = one + else: + db = "" + user = "" + output.append(("Current database:", db)) + output.append(("Current user:", user)) if iocommands.is_pager_enabled(): - if 'PAGER' in os.environ: - pager = os.environ['PAGER'] + if "PAGER" in os.environ: + pager = os.environ["PAGER"] else: - pager = 'System default' + pager = "System default" else: - pager = 'stdout' - output.append(('Current pager:', pager)) + pager = "stdout" + output.append(("Current pager:", pager)) - output.append(('Server version:', '{0} {1}'.format( - variables['version'], variables['version_comment']))) - output.append(('Protocol version:', variables['protocol_version'])) + output.append(("Using delimiter:", iocommands.get_current_delimiter())) + output.append(("Using outfile:", iocommands.tee_file.name if iocommands.tee_file else '')) - if 'unix' in cur.connection.host_info.lower(): - host_info = cur.connection.host_info + output.append(("Server version:", f'{global_variables["version"]} {global_variables["version_comment"]}')) + output.append(("Protocol version:", global_variables["protocol_version"])) + if cipher := get_ssl_cipher(cur): + output.append(('SSL:', f'Cipher in use is {cipher}')) else: - host_info = '{0} via TCP/IP'.format(cur.connection.host) - - output.append(('Connection:', host_info)) + output.append(('SSL:', '')) + output.append(('SSL/TLS version:', get_ssl_version(cur) or '')) - query = ('SELECT @@character_set_server, @@character_set_database, ' - '@@character_set_client, @@character_set_connection LIMIT 1;') - log.debug(query) - cur.execute(query) - charset = cur.fetchone() - output.append(('Server characterset:', charset[0])) - output.append(('Db characterset:', charset[1])) - output.append(('Client characterset:', charset[2])) - output.append(('Conn. characterset:', charset[3])) + if getattr(cur.connection, 'unix_socket', None): + host_info = cur.connection.host_info + else: + host_info = f'{cur.connection.host} via TCP/IP' + + output.append(("Connection:", host_info)) + + charset_spec = [ + {'name': 'Server characterset:', 'variable': 'character_set_server'}, + {'name': 'Db characterset:', 'variable': 'character_set_database'}, + {'name': 'Client characterset:', 'variable': 'character_set_client'}, + {'name': 'Conn. characterset:', 'variable': 'character_set_connection'}, + {'name': 'Result characterset:', 'variable': 'character_set_results'}, + ] + for elt in charset_spec: + if elt['variable'] in session_variables: + value = session_variables[elt['variable']] + else: + value = '' + output.append((elt['name'], value)) - if 'TCP/IP' in host_info: - output.append(('TCP port:', cur.connection.port)) + if getattr(cur.connection, 'unix_socket', None): + output.append(('UNIX socket:', global_variables['socket'])) else: - output.append(('UNIX socket:', variables['socket'])) + output.append(('TCP port:', cur.connection.port)) + + output.append(('Server timezone:', get_server_timezone(global_variables))) + output.append(('Local timezone:', get_local_timezone())) - if 'Uptime' in status: - output.append(('Uptime:', format_uptime(status['Uptime']))) + if "Uptime" in status: + output.append(("Uptime:", format_uptime(status["Uptime"]))) - if 'Threads_connected' in status: + if "Threads_connected" in status: # Print the current server statistics. stats = [] - stats.append('Connections: {0}'.format(status['Threads_connected'])) - if 'Queries' in status: - stats.append('Queries: {0}'.format(status['Queries'])) - stats.append('Slow queries: {0}'.format(status['Slow_queries'])) - stats.append('Opens: {0}'.format(status['Opened_tables'])) - if 'Flush_commands' in status: - stats.append('Flush tables: {0}'.format(status['Flush_commands'])) - stats.append('Open tables: {0}'.format(status['Open_tables'])) - if 'Queries' in status: - queries_per_second = int(status['Queries']) / int(status['Uptime']) - stats.append('Queries per second avg: {:.3f}'.format( - queries_per_second)) - stats = ' '.join(stats) - footer.append('\n' + stats) - - footer.append('--------------') - return [('\n'.join(title), output, '', '\n'.join(footer))] + stats.append(f'Connections: {status["Threads_connected"]}') + if "Queries" in status: + stats.append(f'Queries: {status["Queries"]}') + stats.append(f'Slow queries: {status["Slow_queries"]}') + stats.append(f'Opens: {status["Opened_tables"]}') + if "Flush_commands" in status: + stats.append(f'Flush tables: {status["Flush_commands"]}') + stats.append(f'Open tables: {status["Open_tables"]}') + if "Queries" in status: + queries_per_second = int(status["Queries"]) / int(status["Uptime"]) + stats.append(f'Queries per second avg: {queries_per_second:.3f}') + stats_str = " ".join(stats) + footer.append("\n" + stats_str) + + footer.append("--------------") + + return [SQLResult(preamble="\n".join(preamble), header=header, rows=output, postamble="\n".join(footer))] diff --git a/mycli/packages/special/delimitercommand.py b/mycli/packages/special/delimitercommand.py index 994b134b..cceb643d 100644 --- a/mycli/packages/special/delimitercommand.py +++ b/mycli/packages/special/delimitercommand.py @@ -1,45 +1,51 @@ +from __future__ import annotations + import re +from typing import Generator + import sqlparse +from mycli.packages.sqlresult import SQLResult + +sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] +sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] + -class DelimiterCommand(object): - def __init__(self): - self._delimiter = ';' +class DelimiterCommand: + def __init__(self) -> None: + self._delimiter = ";" - def _split(self, sql): + def _split(self, sql: str) -> list[str]: """Temporary workaround until sqlparse.split() learns about custom delimiters.""" placeholder = "\ufffc" # unicode object replacement character - if self._delimiter == ';': + if self._delimiter == ";": return sqlparse.split(sql) # We must find a string that original sql does not contain. # Most likely, our placeholder is enough, but if not, keep looking while placeholder in sql: placeholder += placeholder[0] - sql = sql.replace(';', placeholder) - sql = sql.replace(self._delimiter, ';') + sql = sql.replace(";", placeholder) + sql = sql.replace(self._delimiter, ";") split = sqlparse.split(sql) - return [ - stmt.replace(';', self._delimiter).replace(placeholder, ';') - for stmt in split - ] + return [stmt.replace(";", self._delimiter).replace(placeholder, ";") for stmt in split] - def queries_iter(self, input): + def queries_iter(self, input_str: str) -> Generator[str, None, None]: """Iterate over queries in the input string.""" - queries = self._split(input) + queries = self._split(input_str) while queries: for sql in queries: delimiter = self._delimiter sql = queries.pop(0) if sql.endswith(delimiter): trailing_delimiter = True - sql = sql.strip(delimiter) + sql = sql[: -len(delimiter)] else: trailing_delimiter = False @@ -49,12 +55,12 @@ def queries_iter(self, input): # re-split everything, and if we previously stripped # the delimiter, append it to the end if self._delimiter != delimiter: - combined_statement = ' '.join([sql] + queries) + combined_statement = " ".join([sql] + queries) if trailing_delimiter: combined_statement += delimiter queries = self._split(combined_statement)[1:] - def set(self, arg, **_): + def set(self, arg: str, **_) -> list[SQLResult]: """Change delimiter. Since `arg` is everything that follows the DELIMITER token @@ -63,18 +69,18 @@ def set(self, arg, **_): word of it. """ - match = arg and re.search(r'[^\s]+', arg) + match = arg and re.search(r"[^\s]+", arg) if not match: - message = 'Missing required argument, delimiter' - return [(None, None, None, message)] + message = "Missing required argument, delimiter" + return [SQLResult(status=message)] delimiter = match.group() - if delimiter.lower() == 'delimiter': - return [(None, None, None, 'Invalid delimiter "delimiter"')] + if delimiter.lower() == "delimiter": + return [SQLResult(status='Invalid delimiter "delimiter"')] self._delimiter = delimiter - return [(None, None, None, "Changed delimiter to {}".format(delimiter))] + return [SQLResult(status=f'Changed delimiter to {delimiter}')] @property - def current(self): + def current(self) -> str: return self._delimiter diff --git a/mycli/packages/special/favoritequeries.py b/mycli/packages/special/favoritequeries.py index 0b91400e..1233ee85 100644 --- a/mycli/packages/special/favoritequeries.py +++ b/mycli/packages/special/favoritequeries.py @@ -1,8 +1,10 @@ -class FavoriteQueries(object): +from __future__ import annotations - section_name = 'favorite_queries' - usage = ''' +class FavoriteQueries: + section_name: str = "favorite_queries" + + usage = """ Favorite Queries are a way to save frequently used queries with a short name. Examples: @@ -28,36 +30,36 @@ class FavoriteQueries(object): # Delete a favorite query. > \\fd simple - simple: Deleted -''' + simple: Deleted. +""" # Class-level variable, for convenience to use as a singleton. - instance = None + instance: FavoriteQueries - def __init__(self, config): + def __init__(self, config) -> None: self.config = config @classmethod def from_config(cls, config): return FavoriteQueries(config) - def list(self): - return self.config.get(self.section_name, []) + def list(self) -> list[str | None]: + return list(self.config.get(self.section_name, {})) - def get(self, name): + def get(self, name) -> str | None: return self.config.get(self.section_name, {}).get(name, None) - def save(self, name, query): - self.config.encoding = 'utf-8' + def save(self, name: str, query: str) -> None: + self.config.encoding = "utf-8" if self.section_name not in self.config: self.config[self.section_name] = {} self.config[self.section_name][name] = query self.config.write() - def delete(self, name): + def delete(self, name: str) -> str: try: del self.config[self.section_name][name] except KeyError: - return '%s: Not Found.' % name + return f'{name}: Not Found.' self.config.write() - return '%s: Deleted' % name + return f'{name}: Deleted.' diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 01f3c7ba..2a29c7cf 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -1,169 +1,279 @@ -import os -import re +from __future__ import annotations + import locale import logging -import subprocess +import os +import re import shlex -from io import open +import subprocess from time import sleep +from typing import Any, Generator import click +from configobj import ConfigObj +from prompt_toolkit.formatted_text import ANSI, FormattedText, to_plain_text +from pymysql.cursors import Cursor import pyperclip import sqlparse -from . import export -from .main import special_command, NO_QUERY, PARSED_QUERY -from .favoritequeries import FavoriteQueries -from .delimitercommand import DelimiterCommand -from .utils import handle_cd_command -from mycli.packages.prompt_utils import confirm_destructive_query +from mycli.compat import WIN +from mycli.packages.interactive_utils import confirm_destructive_query +from mycli.packages.special.delimitercommand import DelimiterCommand +from mycli.packages.special.favoritequeries import FavoriteQueries +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.special.main import ArgType, SpecialCommandAlias, special_command +from mycli.packages.special.main import execute as special_execute +from mycli.packages.special.utils import handle_cd_command +from mycli.packages.sqlresult import SQLResult + +sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] +sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] TIMING_ENABLED = False use_expanded_output = False +force_horizontal_output = False PAGER_ENABLED = True +SHOW_FAVORITE_QUERY = True tee_file = None once_file = None written_to_once_file = False -pipe_once_process = None -written_to_pipe_once_process = False +PIPE_ONCE: dict[str, Any] = { + 'process': None, + 'stdin': [], + 'stdout_file': None, + 'stdout_mode': None, +} delimiter_command = DelimiterCommand() +favoritequeries = FavoriteQueries(ConfigObj()) +DESTRUCTIVE_KEYWORDS: list[str] = [] +SHOW_WARNINGS_ENABLED: bool = False + +def set_favorite_queries(config): + global favoritequeries + favoritequeries = FavoriteQueries(config) -@export -def set_timing_enabled(val): + +def set_timing_enabled(val: bool) -> None: global TIMING_ENABLED TIMING_ENABLED = val -@export -def set_pager_enabled(val): + +def set_pager_enabled(val: bool) -> None: global PAGER_ENABLED PAGER_ENABLED = val -@export -def is_pager_enabled(): +def is_pager_enabled() -> bool: return PAGER_ENABLED -@export -@special_command('pager', '\\P [command]', - 'Set PAGER. Print the query results via PAGER.', - arg_type=PARSED_QUERY, aliases=('\\P', ), case_sensitive=True) -def set_pager(arg, **_): + +def set_show_favorite_query(val: bool) -> None: + global SHOW_FAVORITE_QUERY + SHOW_FAVORITE_QUERY = val + + +def is_show_favorite_query() -> bool: + return SHOW_FAVORITE_QUERY + + +def set_destructive_keywords(val: list[str]) -> None: + global DESTRUCTIVE_KEYWORDS + DESTRUCTIVE_KEYWORDS = val + + +def set_show_warnings_enabled(val: bool) -> None: + global SHOW_WARNINGS_ENABLED + SHOW_WARNINGS_ENABLED = val + + +def is_show_warnings_enabled() -> bool: + return SHOW_WARNINGS_ENABLED + + +@special_command( + 'warnings', + 'warnings', + 'Enable automatic warnings display.', + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias('\\W', case_sensitive=True)], +) +def enable_show_warnings() -> Generator[SQLResult, None, None]: + global SHOW_WARNINGS_ENABLED + SHOW_WARNINGS_ENABLED = True + msg = "Show warnings enabled." + yield SQLResult(status=msg) + + +@special_command( + 'nowarnings', + 'nowarnings', + 'Disable automatic warnings display.', + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias('\\w', case_sensitive=True)], +) +def disable_show_warnings() -> Generator[SQLResult, None, None]: + global SHOW_WARNINGS_ENABLED + SHOW_WARNINGS_ENABLED = False + msg = 'Show warnings disabled.' + yield SQLResult(status=msg) + + +@special_command( + "pager", + "pager [command]", + "Set pager to [command]. Print query results via pager.", + arg_type=ArgType.PARSED_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\P", case_sensitive=True)], +) +def set_pager(arg: str, **_) -> list[SQLResult]: if arg: - os.environ['PAGER'] = arg - msg = 'PAGER set to %s.' % arg + os.environ["PAGER"] = arg + msg = f"PAGER set to {arg}." set_pager_enabled(True) else: - if 'PAGER' in os.environ: - msg = 'PAGER set to %s.' % os.environ['PAGER'] + if "PAGER" in os.environ: + msg = f"PAGER set to {os.environ['PAGER']}." else: # This uses click's default per echo_via_pager. - msg = 'Pager enabled.' + msg = "Pager enabled." set_pager_enabled(True) - return [(None, None, None, msg)] + return [SQLResult(status=msg)] + -@export -@special_command('nopager', '\\n', 'Disable pager, print to stdout.', - arg_type=NO_QUERY, aliases=('\\n', ), case_sensitive=True) -def disable_pager(): +@special_command( + "nopager", + "nopager", + "Disable pager; print to stdout.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\n", case_sensitive=True)], +) +def disable_pager() -> list[SQLResult]: set_pager_enabled(False) - return [(None, None, None, 'Pager disabled.')] + return [SQLResult(status="Pager disabled.")] + -@special_command('\\timing', '\\t', 'Toggle timing of commands.', arg_type=NO_QUERY, aliases=('\\t', ), case_sensitive=True) -def toggle_timing(): +@special_command( + "\\timing", + "\\timing", + "Toggle timing of queries.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\t", case_sensitive=True)], +) +def toggle_timing() -> list[SQLResult]: global TIMING_ENABLED TIMING_ENABLED = not TIMING_ENABLED message = "Timing is " message += "on." if TIMING_ENABLED else "off." - return [(None, None, None, message)] + return [SQLResult(status=message)] -@export -def is_timing_enabled(): + +def is_timing_enabled() -> bool: return TIMING_ENABLED -@export -def set_expanded_output(val): + +def set_expanded_output(val: bool) -> None: global use_expanded_output use_expanded_output = val -@export -def is_expanded_output(): + +def is_expanded_output() -> bool: return use_expanded_output + +def set_forced_horizontal_output(val: bool) -> None: + global force_horizontal_output + force_horizontal_output = val + + +def forced_horizontal() -> bool: + return force_horizontal_output + + _logger = logging.getLogger(__name__) -@export -def editor_command(command): + +def editor_command(command: str) -> bool: """ Is this an external editor command? :param command: string """ + # special case: allow help on the \edit command + if re.match(r'^([Hh][Ee][Ll][Pp])\s+(\\e|\\edit)\s*(;|\\G|\\g)?\s*$', command): + return False # It is possible to have `\e filename` or `SELECT * FROM \e`. So we check # for both conditions. - return command.strip().endswith('\\e') or command.strip().startswith('\\e') + return ( + command.strip().endswith("\\e") + or command.strip().startswith("\\e ") + or command.strip().endswith("\\edit") + or command.strip().startswith("\\edit ") + ) + -@export -def get_filename(sql): - if sql.strip().startswith('\\e'): - command, _, filename = sql.partition(' ') +def get_filename(sql: str) -> str | None: + if sql.strip().startswith("\\e ") or sql.strip().startswith("\\edit "): + command, _, filename = sql.partition(" ") return filename.strip() or None + else: + return None -@export -def get_editor_query(sql): +def get_editor_query(sql: str) -> str: """Get the query part of an editor command.""" sql = sql.strip() # The reason we can't simply do .strip('\e') is that it strips characters, # not a substring. So it'll strip "e" in the end of the sql also! # Ex: "select * from style\e" -> "select * from styl". - pattern = re.compile(r'(^\\e|\\e$)') + pattern = re.compile(r"(\\e$|\\edit$)") while pattern.search(sql): - sql = pattern.sub('', sql) + sql = pattern.sub("", sql) return sql -@export -def open_external_editor(filename=None, sql=None): +def open_external_editor(filename: str | None = None, sql: str | None = None) -> tuple[str, str | None]: """Open external editor, wait for the user to type in their query, return the query. - - :return: list with one tuple, query as first element. - """ - message = None - filename = filename.strip().split(' ', 1)[0] if filename else None - - sql = sql or '' - MARKER = '# Type your query above this line.\n' - - # Populate the editor buffer with the partial sql (if available) and a - # placeholder comment. - query = click.edit(u'{sql}\n\n{marker}'.format(sql=sql, marker=MARKER), - filename=filename, extension='.sql') + filename = filename.strip().split(" ", 1)[0] if filename else None + sql = sql or "" + MARKER = "# Type your query above this line.\n" if filename: + query = '' + message = None + click.edit(filename=filename) try: - with open(filename) as f: + with open(filename, 'r') as f: query = f.read() except IOError: - message = 'Error reading file: %s.' % filename + message = f'Error reading file: {filename}' + return (query.rstrip('\n'), message) - if query is not None: - query = query.split(MARKER, 1)[0].rstrip('\n') + # Populate the editor buffer with the partial sql (if available) and a + # placeholder comment. + query = click.edit(f"{sql}\n\n{MARKER}", extension=".sql") or '' + + if query: + query = query.split(MARKER, 1)[0].rstrip("\n") else: # Don't return None for the caller to deal with. # Empty string is ok. query = sql - return (query, message) + return (query, None) -@export -def clip_command(command): +def clip_command(command: str) -> bool: """Is this a clip command? :param command: string @@ -171,336 +281,450 @@ def clip_command(command): """ # It is possible to have `\clip` or `SELECT * FROM \clip`. So we check # for both conditions. - return command.strip().endswith('\\clip') or command.strip().startswith('\\clip') + return command.strip().endswith("\\clip") or command.strip().startswith("\\clip") -@export -def get_clip_query(sql): +def get_clip_query(sql: str) -> str: """Get the query part of a clip command.""" sql = sql.strip() # The reason we can't simply do .strip('\clip') is that it strips characters, # not a substring. So it'll strip "c" in the end of the sql also! - pattern = re.compile(r'(^\\clip|\\clip$)') + pattern = re.compile(r"(^\\clip|\\clip$)") while pattern.search(sql): - sql = pattern.sub('', sql) + sql = pattern.sub("", sql) return sql -@export -def copy_query_to_clipboard(sql=None): +def copy_query_to_clipboard(sql: str | None = None) -> str | None: """Send query to the clipboard.""" - sql = sql or '' + sql = sql or "" message = None try: - pyperclip.copy(u'{sql}'.format(sql=sql)) + pyperclip.copy(f"{sql}") except RuntimeError as e: - message = 'Error clipping query: %s.' % e.strerror + message = f"Error clipping query: {e}." return message -@special_command('\\f', '\\f [name [args..]]', 'List or execute favorite queries.', arg_type=PARSED_QUERY, case_sensitive=True) -def execute_favorite_query(cur, arg, **_): - """Returns (title, rows, headers, status)""" - if arg == '': - for result in list_favorite_queries(): - yield result +def set_redirect(command_part: str | None, file_operator_part: str | None, file_part: str | None) -> list[tuple]: + if command_part: + if file_part: + PIPE_ONCE['stdout_file'] = file_part + PIPE_ONCE['stdout_mode'] = 'w' if file_operator_part == '>' else 'a' + return set_pipe_once(command_part) + elif file_operator_part == '>': + return set_once(f'-o {file_part}') + else: + return set_once(file_part) + + +@special_command( + "\\f", + "\\f [name [args..]]", + "List or execute favorite queries.", + arg_type=ArgType.PARSED_QUERY, + case_sensitive=True, +) +def execute_favorite_query(cur: Cursor, arg: str, **_) -> Generator[SQLResult, None, None]: + if arg == "": + yield from list_favorite_queries() + return - """Parse out favorite name and optional substitution parameters""" - name, _, arg_str = arg.partition(' ') + # Parse out favorite name and optional substitution parameters + name, _separator, arg_str = arg.partition(" ") args = shlex.split(arg_str) query = FavoriteQueries.instance.get(name) if query is None: - message = "No favorite query: %s" % (name) - yield (None, None, None, message) + message = f"No favorite query: {name}" + yield SQLResult(status=message) else: query, arg_error = subst_favorite_query_args(query, args) - if arg_error: - yield (None, None, None, arg_error) + if query is None: + yield SQLResult(status=arg_error) else: for sql in sqlparse.split(query): - sql = sql.rstrip(';') - title = '> %s' % (sql) - cur.execute(sql) - if cur.description: - headers = [x[0] for x in cur.description] - yield (title, cur, headers, None) + sql = sql.rstrip(";") + preamble = f"> {sql}" if is_show_favorite_query() else None + is_special = False + for special in SPECIAL_COMMANDS: + if sql.lower().startswith(special.lower()): + is_special = True + break + if is_special: + for result in special_execute(cur, sql): + result.preamble = preamble + # special_execute() already returns a SQLResult + yield result else: - yield (title, None, None, None) + cur.execute(sql) + if cur.description: + header = [x[0] for x in cur.description] + yield SQLResult(preamble=preamble, header=header, rows=cur) + else: + yield SQLResult(preamble=preamble) + -def list_favorite_queries(): - """List of all favorite queries. - Returns (title, rows, headers, status)""" +def list_favorite_queries() -> list[SQLResult]: + """List of all favorite queries.""" - headers = ["Name", "Query"] - rows = [(r, FavoriteQueries.instance.get(r)) - for r in FavoriteQueries.instance.list()] + header = ["Name", "Query"] + rows = [(r, FavoriteQueries.instance.get(r)) for r in FavoriteQueries.instance.list()] if not rows: - status = '\nNo favorite queries found.' + FavoriteQueries.instance.usage + status = "\nNo favorite queries found." + FavoriteQueries.instance.usage else: - status = '' - return [('', rows, headers, status)] + status = "" + return [SQLResult(header=header, rows=rows, status=status)] -def subst_favorite_query_args(query, args): +def subst_favorite_query_args(query: str, args: list[str]) -> list[str | None]: """replace positional parameters ($1...$N) in query.""" for idx, val in enumerate(args): - subst_var = '$' + str(idx + 1) + subst_var = "$" + str(idx + 1) if subst_var not in query: - return [None, 'query does not have substitution parameter ' + subst_var + ':\n ' + query] + return [None, "query does not have substitution parameter " + subst_var + ":\n " + query] query = query.replace(subst_var, val) - match = re.search(r'\$\d+', query) + match = re.search(r"\$\d+", query) if match: - return[None, 'missing substitution for ' + match.group(0) + ' in query:\n ' + query] + return [None, "missing substitution for " + match.group(0) + " in query:\n " + query] return [query, None] -@special_command('\\fs', '\\fs name query', 'Save a favorite query.') -def save_favorite_query(arg, **_): - """Save a new favorite query. - Returns (title, rows, headers, status)""" - usage = 'Syntax: \\fs name query.\n\n' + FavoriteQueries.instance.usage +@special_command( + "\\fs", + "\\fs ", + "Save a favorite query.", +) +def save_favorite_query(arg: str, **_) -> list[SQLResult]: + """Save a new favorite query.""" + + usage = "Syntax: \\fs name query.\n\n" + FavoriteQueries.instance.usage if not arg: - return [(None, None, None, usage)] + return [SQLResult(status=usage)] - name, _, query = arg.partition(' ') + name, _separator, query = arg.partition(" ") # If either name or query is missing then print the usage and complain. if (not name) or (not query): - return [(None, None, None, - usage + 'Err: Both name and query are required.')] + return [SQLResult(status=f"{usage} Err: Both name and query are required.")] FavoriteQueries.instance.save(name, query) - return [(None, None, None, "Saved.")] + return [SQLResult(status="Saved.")] -@special_command('\\fd', '\\fd [name]', 'Delete a favorite query.') -def delete_favorite_query(arg, **_): +@special_command( + "\\fd", + "\\fd ", + "Delete a favorite query.", +) +def delete_favorite_query(arg: str, **_) -> list[SQLResult]: """Delete an existing favorite query.""" - usage = 'Syntax: \\fd name.\n\n' + FavoriteQueries.instance.usage + usage = "Syntax: \\fd name.\n\n" + FavoriteQueries.instance.usage if not arg: - return [(None, None, None, usage)] + return [SQLResult(status=usage)] status = FavoriteQueries.instance.delete(arg) - return [(None, None, None, status)] + return [SQLResult(status=status)] -@special_command('system', 'system [command]', - 'Execute a system shell commmand.') -def execute_system_command(arg, **_): +@special_command( + "system", + "system [-r] ", + "Execute a system shell command (raw mode with -r).", +) +def execute_system_command(arg: str, **_) -> list[SQLResult]: """Execute a system shell command.""" - usage = "Syntax: system [command].\n" + usage = "Syntax: system [-r] [command].\n-r denotes \"raw\" mode, in which output is passed through without formatting." - if not arg: - return [(None, None, None, usage)] + IMPLICIT_RAW_MODE_COMMANDS = { + 'clear', + 'vim', + 'vi', + 'bash', + 'zsh', + } + + if not arg.strip(): + return [SQLResult(status=usage)] try: - command = arg.strip() - if command.startswith('cd'): - ok, error_message = handle_cd_command(arg) - if not ok: - return [(None, None, None, error_message)] - return [(None, None, None, '')] - - args = arg.split(' ') - process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, error = process.communicate() - response = output if not error else error - - # Python 3 returns bytes. This needs to be decoded to a string. - if isinstance(response, bytes): - encoding = locale.getpreferredencoding(False) - response = response.decode(encoding) + command = shlex.split(arg.strip(), posix=not WIN) + except ValueError as e: + return [SQLResult(status=f"Cannot parse system command: {e}")] + + raw = False + if command[0] == '-r': + command.pop(0) + raw = True + elif command[0].lower() in IMPLICIT_RAW_MODE_COMMANDS: + raw = True + + if not command: + return [SQLResult(status=usage)] + + if command[0].lower() == 'cd': + ok, error_message = handle_cd_command(command) + if not ok: + return [SQLResult(status=error_message)] + return [SQLResult()] - return [(None, None, None, response)] + try: + if raw: + completed_process = subprocess.run(command, check=False) + if completed_process.returncode: + return [SQLResult(status=f'Command exited with return code {completed_process.returncode}')] + else: + return [SQLResult()] + else: + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + try: + output, error = process.communicate(timeout=60) + except subprocess.TimeoutExpired: + process.kill() + output, error = process.communicate() + response = output if not error else error + encoding = locale.getpreferredencoding(False) + response_str = response.decode(encoding) + if process.returncode: + status = f'Command exited with return code {process.returncode}' + else: + status = None + return [SQLResult(preamble=response_str, status=status)] except OSError as e: - return [(None, None, None, 'OSError: %s' % e.strerror)] + return [SQLResult(status=f"OSError: {e.strerror}")] -def parseargfile(arg): - if arg.startswith('-o '): +def parseargfile(arg: str) -> tuple[str, str]: + if arg.startswith("-o "): mode = "w" filename = arg[3:] else: - mode = 'a' + mode = "a" filename = arg if not filename: - raise TypeError('You must provide a filename.') + raise TypeError("You must provide a filename.") - return {'file': os.path.expanduser(filename), 'mode': mode} + return (os.path.expanduser(filename), mode) -@special_command('tee', 'tee [-o] filename', - 'Append all results to an output file (overwrite using -o).') -def set_tee(arg, **_): +@special_command( + "tee", + "tee [-o] ", + "Append all results to an output file (overwrite using -o).", +) +def set_tee(arg: str, **_) -> list[SQLResult]: global tee_file try: - tee_file = open(**parseargfile(arg)) + tee_file = open(*parseargfile(arg)) except (IOError, OSError) as e: - raise OSError("Cannot write to file '{}': {}".format(e.filename, e.strerror)) + raise OSError(f"Cannot write to file '{e.filename}': {e.strerror}") from e - return [(None, None, None, "")] + return [SQLResult(status="")] -@export -def close_tee(): + +def close_tee() -> None: global tee_file if tee_file: tee_file.close() tee_file = None -@special_command('notee', 'notee', 'Stop writing results to an output file.') -def no_tee(arg, **_): +@special_command( + "notee", + "notee", + "Stop writing results to an output file.", +) +def no_tee(arg: str, **_) -> list[SQLResult]: close_tee() - return [(None, None, None, "")] + return [SQLResult(status="")] + -@export -def write_tee(output): +def write_tee(output: str | ANSI | FormattedText, nl: bool = True) -> None: global tee_file - if tee_file: - click.echo(output, file=tee_file, nl=False) - click.echo(u'\n', file=tee_file, nl=False) - tee_file.flush() + if not tee_file: + return + click.echo(to_plain_text(output), file=tee_file, nl=False) + if nl: + click.echo('\n', file=tee_file, nl=False) + tee_file.flush() -@special_command('\\once', '\\o [-o] filename', - 'Append next result to an output file (overwrite using -o).', - aliases=('\\o', )) -def set_once(arg, **_): +@special_command( + "\\once", + "\\once [-o] ", + "Append next result to an output file (overwrite using -o).", + aliases=[SpecialCommandAlias("\\o", case_sensitive=False)], +) +def set_once(arg: str, **_) -> list[SQLResult]: global once_file, written_to_once_file try: - once_file = open(**parseargfile(arg)) + once_file = open(*parseargfile(arg)) except (IOError, OSError) as e: - raise OSError("Cannot write to file '{}': {}".format( - e.filename, e.strerror)) + raise OSError(f"Cannot write to file '{e.filename}': {e.strerror}") from e written_to_once_file = False - return [(None, None, None, "")] + return [SQLResult(status="")] + +def is_redirected() -> bool: + return bool(once_file or PIPE_ONCE['process']) -@export -def write_once(output): + +def write_once(output: str) -> None: global once_file, written_to_once_file if output and once_file: click.echo(output, file=once_file, nl=False) - click.echo(u"\n", file=once_file, nl=False) + click.echo("\n", file=once_file, nl=False) once_file.flush() written_to_once_file = True -@export -def unset_once_if_written(): +def unset_once_if_written(post_redirect_command: str) -> None: """Unset the once file, if it has been written to.""" global once_file, written_to_once_file if written_to_once_file and once_file: + once_filename = once_file.name once_file.close() once_file = None + _run_post_redirect_hook(post_redirect_command, once_filename) -@special_command('\\pipe_once', '\\| command', - 'Send next result to a subprocess.', - aliases=('\\|', )) -def set_pipe_once(arg, **_): - global pipe_once_process, written_to_pipe_once_process - pipe_once_cmd = shlex.split(arg) - if len(pipe_once_cmd) == 0: +def _run_post_redirect_hook(post_redirect_command: str, filename: str) -> None: + if not post_redirect_command: + return + post_cmd = post_redirect_command.format(shlex.quote(filename)) + try: + subprocess.run( + post_cmd, + shell=True, + check=True, + stdin=subprocess.DEVNULL, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + except Exception as e: + raise OSError(f"Redirect post hook failed: {e}") from e + + +@special_command( + "\\pipe_once", + "\\pipe_once ", + "Send next result to a subprocess.", + aliases=[SpecialCommandAlias("\\|", case_sensitive=False)], +) +def set_pipe_once(arg: str, **_) -> list[SQLResult]: + if not arg: raise OSError("pipe_once requires a command") - written_to_pipe_once_process = False - pipe_once_process = subprocess.Popen(pipe_once_cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - bufsize=1, - encoding='UTF-8', - universal_newlines=True) - return [(None, None, None, "")] - - -@export -def write_pipe_once(output): - global pipe_once_process, written_to_pipe_once_process - if output and pipe_once_process: - try: - click.echo(output, file=pipe_once_process.stdin, nl=False) - click.echo(u"\n", file=pipe_once_process.stdin, nl=False) - except (IOError, OSError) as e: - pipe_once_process.terminate() - raise OSError( - "Failed writing to pipe_once subprocess: {}".format(e.strerror)) - written_to_pipe_once_process = True - - -@export -def unset_pipe_once_if_written(): - """Unset the pipe_once cmd, if it has been written to.""" - global pipe_once_process, written_to_pipe_once_process - if written_to_pipe_once_process: - (stdout_data, stderr_data) = pipe_once_process.communicate() - if len(stdout_data) > 0: - print(stdout_data.rstrip(u"\n")) - if len(stderr_data) > 0: - print(stderr_data.rstrip(u"\n")) - pipe_once_process = None - written_to_pipe_once_process = False + if WIN: + # best effort, no chaining + pipe_once_cmd = shlex.split(arg) + else: + # to support chaining + pipe_once_cmd = ['sh', '-c', arg] + PIPE_ONCE['stdin'] = [] + PIPE_ONCE['process'] = subprocess.Popen( + pipe_once_cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="UTF-8", + universal_newlines=True, + ) + return [SQLResult(status="")] + + +def write_pipe_once(line: str) -> None: + if line and PIPE_ONCE['process']: + PIPE_ONCE['stdin'].append(line) + + +def flush_pipe_once_if_written(post_redirect_command: str) -> None: + """Flush the pipe_once cmd, if lines have been written.""" + if not PIPE_ONCE['process']: + return + if not PIPE_ONCE['stdin']: + return + try: + (stdout_data, stderr_data) = PIPE_ONCE['process'].communicate(input='\n'.join(PIPE_ONCE['stdin']) + '\n', timeout=60) + except subprocess.TimeoutExpired: + PIPE_ONCE['process'].kill() + (stdout_data, stderr_data) = PIPE_ONCE['process'].communicate() + if stdout_data: + if PIPE_ONCE['stdout_file']: + with open(PIPE_ONCE['stdout_file'], PIPE_ONCE['stdout_mode']) as f: + print(stdout_data, file=f) + _run_post_redirect_hook(post_redirect_command, PIPE_ONCE['stdout_file']) + else: + click.secho(stdout_data.rstrip('\n')) + if stderr_data: + click.secho(stderr_data.rstrip('\n'), err=True, fg='red') + if returncode := PIPE_ONCE['process'].returncode: + PIPE_ONCE['process'] = None + PIPE_ONCE['stdin'] = [] + PIPE_ONCE['stdout_file'] = None + PIPE_ONCE['stdout_mode'] = None + raise OSError(f'process exited with nonzero code {returncode}') + PIPE_ONCE['process'] = None + PIPE_ONCE['stdin'] = [] + PIPE_ONCE['stdout_file'] = None + PIPE_ONCE['stdout_mode'] = None @special_command( - 'watch', - 'watch [seconds] [-c] query', - 'Executes the query every [seconds] seconds (by default 5).' + "watch", + "watch [seconds] [-c] ", + "Execute query every [seconds] seconds (5 by default).", ) -def watch_query(arg, **kwargs): +def watch_query(arg: str, **kwargs) -> Generator[SQLResult, None, None]: usage = """Syntax: watch [seconds] [-c] query. * seconds: The interval at the query will be repeated, in seconds. By default 5. * -c: Clears the screen between every iteration. """ if not arg: - yield (None, None, None, usage) + yield SQLResult(status=usage) return - seconds = 5 + seconds = 5.0 clear_screen = False statement = None while statement is None: arg = arg.strip() if not arg: # Oops, we parsed all the arguments without finding a statement - yield (None, None, None, usage) + yield SQLResult(status=usage) return - (current_arg, _, arg) = arg.partition(' ') + (left_arg, _, right_arg) = arg.partition(" ") + arg = right_arg try: - seconds = float(current_arg) + seconds = float(left_arg) continue except ValueError: pass - if current_arg == '-c': + if left_arg == "-c": clear_screen = True continue - statement = '{0!s} {1!s}'.format(current_arg, arg) - destructive_prompt = confirm_destructive_query(statement) + statement = f"{left_arg} {arg}" + destructive_prompt = confirm_destructive_query(DESTRUCTIVE_KEYWORDS, statement) if destructive_prompt is False: click.secho("Wise choice!") return elif destructive_prompt is True: click.secho("Your call!") - cur = kwargs['cur'] - sql_list = [ - (sql.rstrip(';'), "> {0!s}".format(sql)) - for sql in sqlparse.split(statement) - ] + cur = kwargs["cur"] + sql_list = [(sql.rstrip(";"), f"> {sql}") for sql in sqlparse.split(statement)] old_pager_enabled = is_pager_enabled() while True: if clear_screen: @@ -509,13 +733,17 @@ def watch_query(arg, **kwargs): # Somewhere in the code the pager its activated after every yield, # so we disable it in every iteration set_pager_enabled(False) - for (sql, title) in sql_list: + for sql, preamble in sql_list: cur.execute(sql) + command: dict[str, str | float] = { + "name": "watch", + "seconds": seconds, + } if cur.description: - headers = [x[0] for x in cur.description] - yield (title, cur, headers, None) + header = [x[0] for x in cur.description] + yield SQLResult(preamble=preamble, header=header, rows=cur, command=command) else: - yield (title, None, None, None) + yield SQLResult(preamble=preamble, command=command) sleep(seconds) except KeyboardInterrupt: # This prints the Ctrl-C character in its own line, which prevents @@ -526,18 +754,18 @@ def watch_query(arg, **kwargs): set_pager_enabled(old_pager_enabled) -@export -@special_command('delimiter', None, 'Change SQL delimiter.') -def set_delimiter(arg, **_): +@special_command( + "delimiter", + "delimiter ", + "Change end-of-statement delimiter.", +) +def set_delimiter(arg: str, **_) -> list[SQLResult]: return delimiter_command.set(arg) -@export -def get_current_delimiter(): +def get_current_delimiter() -> str: return delimiter_command.current -@export -def split_queries(input): - for query in delimiter_command.queries_iter(input): - yield query +def split_queries(input_str: str) -> Generator[str, None, None]: + yield from delimiter_command.queries_iter(input_str) diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py new file mode 100644 index 00000000..e7786092 --- /dev/null +++ b/mycli/packages/special/llm.py @@ -0,0 +1,404 @@ +import contextlib +import functools +import io +import logging +import os +import re +from runpy import run_module +import shlex +import sys +from time import time +from typing import Any + +import click + +try: + if not os.environ.get('MYCLI_LLM_OFF'): + import llm + + LLM_IMPORTED = True + else: + LLM_IMPORTED = False +except ImportError: + LLM_IMPORTED = False +try: + if not os.environ.get('MYCLI_LLM_OFF'): + from llm.cli import cli + + LLM_CLI_IMPORTED = True + else: + LLM_CLI_IMPORTED = False +except ImportError: + LLM_CLI_IMPORTED = False +from pymysql.cursors import Cursor + +from mycli.packages.special.main import CommandVerbosity, parse_special_command +from mycli.packages.sqlresult import SQLResult + +log = logging.getLogger(__name__) + +LLM_TEMPLATE_NAME = "mycli-llm-template" + +SCHEMA_DATA_CACHE: dict[str, str] = {} + +SAMPLE_DATA_CACHE: dict[str, dict] = {} + + +def run_external_cmd( + cmd: str, + *args, + capture_output=False, + restart_cli=False, + raise_exception=True, +) -> tuple[int, str]: + original_exe = sys.executable + original_args = sys.argv + try: + sys.argv = [cmd] + list(args) + code = 0 + if capture_output: + buffer = io.StringIO() + redirect: contextlib.ExitStack[bool | None] | contextlib.nullcontext[None] = contextlib.ExitStack() + assert isinstance(redirect, contextlib.ExitStack) + redirect.enter_context(contextlib.redirect_stdout(buffer)) + redirect.enter_context(contextlib.redirect_stderr(buffer)) + else: + redirect = contextlib.nullcontext() + with redirect: + try: + run_module(cmd, run_name="__main__") + except SystemExit as e: + code = int(e.code or 0) + if code != 0 and raise_exception: + if capture_output: + raise RuntimeError(buffer.getvalue()) from e + raise RuntimeError(f"Command {cmd} failed with exit code {code}.") from e + except Exception as e: + code = 1 + if raise_exception: + if capture_output: + raise RuntimeError(buffer.getvalue()) from e + raise RuntimeError(f"Command {cmd} failed: {e}") from e + if restart_cli and code == 0: + os.execv(original_exe, [original_exe] + original_args) + if capture_output: + return code, buffer.getvalue() + else: + return code, "" + finally: + sys.argv = original_args + + +def _build_command_tree(cmd) -> dict[str, Any] | None: + tree: dict[str, Any] | None = {} + assert isinstance(tree, dict) + if isinstance(cmd, click.Group): + for name, subcmd in cmd.commands.items(): + if cmd.name == "models" and name == "default": + tree[name] = {x.model_id: None for x in llm.get_models()} + else: + tree[name] = _build_command_tree(subcmd) + else: + tree = None + return tree + + +def build_command_tree(cmd) -> dict[str, Any]: + return _build_command_tree(cmd) or {} + + +# Generate the command tree for autocompletion +COMMAND_TREE = build_command_tree(cli) if LLM_CLI_IMPORTED is True else {} + + +def get_completions( + tokens: list[str], + tree: dict[str, Any] | None = None, +) -> list[str]: + tree = tree or COMMAND_TREE + for token in tokens: + if token.startswith("-"): + continue + if tree and token in tree: + tree = tree[token] + else: + return [] + return list(tree.keys()) if tree else [] + + +class FinishIteration(Exception): + def __init__(self, results=None): + self.results = results + + +USAGE = """ +Use an LLM to create SQL queries to answer questions from your database. +Examples: + +# Ask a question. +> \\llm 'Most visited urls?' + +# List available models +> \\llm models +> gpt-4o +> gpt-3.5-turbo + +# Change default model +> \\llm models default llama3 + +# Set api key (not required for local models) +> \\llm keys set openai + +# Install a model plugin +> \\llm install llm-ollama +> llm-ollama installed. + +# Plugins directory +# https://llm.datasette.io/en/stable/plugins/directory.html +""" + +NEED_DEPENDENCIES = """ +To enable LLM features you need to install mycli with LLM support: + + pip install 'mycli[llm]' + +or + + pip install 'mycli[all]' + +or install LLM libraries separately + + pip install llm + +This is required to use the \\llm command. +""" + +_SQL_CODE_FENCE = r"```sql\n(.*?)\n```" + +PROMPT = """ +You are a helpful assistant who is a MySQL expert. You are embedded in a mysql +cli tool called mycli. + +Answer this question: + +$question + +Use the following context if it is relevant to answering the question. If the +question is not about the current database then ignore the context. + +You are connected to a MySQL database with the following schema: + +$db_schema + +Here is a sample row of data from each table: + +$sample_data + +If the answer can be found using a SQL query, include a sql query in a code +fence such as this one: + +```sql +SELECT count(*) FROM table_name; +``` +Keep your explanation concise and focused on the question asked. +""" + + +def ensure_mycli_template(replace: bool = False) -> None: + if not replace: + code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False) + if code == 0: + return + run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME) + + +@functools.cache +def cli_commands() -> list[str]: + return list(cli.commands.keys()) + + +def handle_llm( + text: str, + cur: Cursor, + dbname: str, + prompt_field_truncate: int, + prompt_section_truncate: int, +) -> tuple[str, str | None, float]: + _, command_verbosity, arg = parse_special_command(text) + if not LLM_IMPORTED: + raise FinishIteration(results=[SQLResult(preamble=NEED_DEPENDENCIES)]) + if arg.strip().lower() in ['', 'help', '?', r'\?']: + raise FinishIteration(results=[SQLResult(preamble=USAGE)]) + parts = shlex.split(arg) + restart = False + if "-c" in parts: + capture_output = True + use_context = False + elif "prompt" in parts: + capture_output = True + use_context = True + elif "install" in parts or "uninstall" in parts: + capture_output = False + use_context = False + restart = True + elif parts and parts[0] in cli_commands(): + capture_output = False + use_context = False + elif parts and parts[0] == "--help": + capture_output = False + use_context = False + else: + capture_output = True + use_context = True + if not use_context: + args = parts + if capture_output: + click.echo("Calling llm command") + start = time() + _, output = run_external_cmd("llm", *args, capture_output=capture_output) + end = time() + match = re.search(_SQL_CODE_FENCE, output, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + raise FinishIteration(results=[SQLResult(preamble=output)]) + return (output if command_verbosity == CommandVerbosity.SUCCINCT else "", sql, end - start) + else: + run_external_cmd("llm", *args, restart_cli=restart) + raise FinishIteration(results=None) + try: + ensure_mycli_template() + start = time() + context, sql = sql_using_llm( + cur=cur, + question=arg, + dbname=dbname, + prompt_field_truncate=prompt_field_truncate, + prompt_section_truncate=prompt_section_truncate, + ) + end = time() + if command_verbosity == CommandVerbosity.SUCCINCT: + context = "" + return (context, sql, end - start) + except Exception as e: + raise RuntimeError(e) from e + + +def is_llm_command(command: str) -> bool: + cmd, _, _ = parse_special_command(command) + return cmd in ("\\llm", "\\ai") + + +def truncate_list_elements(row: list, prompt_field_truncate: int, prompt_section_truncate: int) -> list: + if not prompt_section_truncate and not prompt_field_truncate: + return row + + width = prompt_field_truncate + while width >= 0: + truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row] + if prompt_section_truncate: + if sum(sys.getsizeof(x) for x in truncated_row) <= prompt_section_truncate: + break + width -= 100 + else: + break + return truncated_row + + +def truncate_table_lines(table: list[str], prompt_section_truncate: int) -> list[str]: + if not prompt_section_truncate: + return table + + truncated_table: list[str] = [] + running_sum = 0 + while table: + line = table.pop(0) + line_size = sys.getsizeof(line) + if running_sum + line_size > prompt_section_truncate: + break + running_sum += line_size + truncated_table.append(line) + return truncated_table + + +def get_schema(cur: Cursor, dbname: str, prompt_section_truncate: int) -> str: + if dbname in SCHEMA_DATA_CACHE: + return SCHEMA_DATA_CACHE[dbname] + click.echo("Preparing schema information to feed the LLM") + schema_query = f""" + SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS `schema` + FROM information_schema.columns + WHERE table_schema = '{dbname}' + GROUP BY table_name + ORDER BY table_name + """ + cur.execute(schema_query) + db_schema = [row for (row,) in cur.fetchall()] + summary = '\n'.join(truncate_table_lines(db_schema, prompt_section_truncate)) + SCHEMA_DATA_CACHE[dbname] = summary + return summary + + +def get_sample_data( + cur: Cursor, + dbname: str, + prompt_field_truncate: int, + prompt_section_truncate: int, +) -> dict[str, Any]: + if dbname in SAMPLE_DATA_CACHE: + return SAMPLE_DATA_CACHE[dbname] + click.echo("Preparing sample data to feed the LLM") + tables_query = "SHOW TABLES" + sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1" + cur.execute(tables_query) + sample_data = {} + for (table_name,) in cur.fetchall(): + try: + cur.execute(sample_row_query.format(dbname=dbname, table=table_name)) + except Exception: + continue + cols = [desc[0] for desc in cur.description] + row = cur.fetchone() + if row is None: + continue + sample_data[table_name] = list( + zip(cols, truncate_list_elements(list(row), prompt_field_truncate, prompt_section_truncate), strict=False) + ) + SAMPLE_DATA_CACHE[dbname] = sample_data + return sample_data + + +def sql_using_llm( + cur: Cursor | None, + question: str | None, + dbname: str = '', + prompt_field_truncate: int = 0, + prompt_section_truncate: int = 0, +) -> tuple[str, str | None]: + if cur is None: + raise RuntimeError("Connect to a database and try again.") + if dbname == '': + raise RuntimeError("Choose a schema and try again.") + args = [ + "--template", + LLM_TEMPLATE_NAME, + "--param", + "db_schema", + get_schema(cur, dbname, prompt_section_truncate), + "--param", + "sample_data", + get_sample_data(cur, dbname, prompt_field_truncate, prompt_section_truncate), + "--param", + "question", + question, + " ", + ] + click.echo("Invoking llm command with schema information and sample data") + _, result = run_external_cmd("llm", *args, capture_output=True) + click.echo("Received response from the llm command") + match = re.search(_SQL_CODE_FENCE, result, re.DOTALL) + if match: + sql = match.group(1).strip() + else: + sql = "" + return (result, sql) diff --git a/mycli/packages/special/main.py b/mycli/packages/special/main.py index ab04f30d..12a6c7de 100644 --- a/mycli/packages/special/main.py +++ b/mycli/packages/special/main.py @@ -1,120 +1,305 @@ +from dataclasses import dataclass +from enum import Enum import logging -from collections import namedtuple +import os +from typing import Callable +import webbrowser -from . import export +from mycli.constants import DOCS_URL, ISSUES_URL +from mycli.packages.sqlresult import SQLResult -log = logging.getLogger(__name__) +try: + if not os.environ.get('MYCLI_LLM_OFF'): + import llm # noqa: F401 + + LLM_IMPORTED = True + else: + LLM_IMPORTED = False +except ImportError: + LLM_IMPORTED = False +from pymysql.cursors import Cursor + +logger = logging.getLogger(__name__) + +COMMANDS: dict[str, 'SpecialCommand'] = {} +CASE_SENSITIVE_COMMANDS: set[str] = set() +CASE_INSENSITIVE_COMMANDS: set[str] = set() + + +class ArgType(Enum): + NO_QUERY = 0 + PARSED_QUERY = 1 + RAW_QUERY = 2 + + +@dataclass(frozen=True) +class SpecialCommandAlias: + command: str + case_sensitive: bool -NO_QUERY = 0 -PARSED_QUERY = 1 -RAW_QUERY = 2 -SpecialCommand = namedtuple('SpecialCommand', - ['handler', 'command', 'shortcut', 'description', 'arg_type', 'hidden', - 'case_sensitive']) +@dataclass(frozen=True) +class SpecialCommand: + handler: Callable + command: str + usage: str + description: str + arg_type: ArgType + hidden: bool | None + case_sensitive: bool | None + aliases: list[SpecialCommandAlias] | None -COMMANDS = {} -@export class CommandNotFound(Exception): pass -@export -def parse_special_command(sql): - command, _, arg = sql.partition(' ') - verbose = '+' in command - command = command.strip().replace('+', '') - return (command, verbose, arg.strip()) -@export -def special_command(command, shortcut, description, arg_type=PARSED_QUERY, - hidden=False, case_sensitive=False, aliases=()): +class CommandVerbosity(Enum): + SUCCINCT = "succinct" + NORMAL = "normal" + VERBOSE = "verbose" + + +def parse_special_command(sql: str) -> tuple[str, CommandVerbosity, str]: + command, _, arg = sql.partition(" ") + command_verbosity = CommandVerbosity.NORMAL + if "+" in command: + command_verbosity = CommandVerbosity.VERBOSE + elif "-" in command: + command_verbosity = CommandVerbosity.SUCCINCT + command = command.strip().strip("+-") + return (command, command_verbosity, arg.strip()) + + +def special_command( + command: str, + usage: str, + description: str, + arg_type: ArgType = ArgType.PARSED_QUERY, + hidden: bool = False, + case_sensitive: bool = False, + aliases: list[SpecialCommandAlias] | None = None, +) -> Callable: def wrapper(wrapped): - register_special_command(wrapped, command, shortcut, description, - arg_type, hidden, case_sensitive, aliases) + register_special_command( + wrapped, + command, + usage, + description, + arg_type=arg_type, + hidden=hidden, + case_sensitive=case_sensitive, + aliases=aliases, + ) return wrapped + return wrapper -@export -def register_special_command(handler, command, shortcut, description, - arg_type=PARSED_QUERY, hidden=False, case_sensitive=False, aliases=()): + +def register_special_command( + handler: Callable, + command: str, + usage: str, + description: str, + arg_type: ArgType = ArgType.PARSED_QUERY, + hidden: bool = False, + case_sensitive: bool = False, + aliases: list[SpecialCommandAlias] | None = None, +) -> None: cmd = command.lower() if not case_sensitive else command - COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, - arg_type, hidden, case_sensitive) + COMMANDS[cmd] = SpecialCommand( + handler, + command, + usage, + description, + arg_type=arg_type, + hidden=hidden, + case_sensitive=case_sensitive, + aliases=aliases, + ) + if case_sensitive: + CASE_SENSITIVE_COMMANDS.add(command) + else: + CASE_INSENSITIVE_COMMANDS.add(command.lower()) + aliases = [] if aliases is None else aliases for alias in aliases: - cmd = alias.lower() if not case_sensitive else alias - COMMANDS[cmd] = SpecialCommand(handler, command, shortcut, description, - arg_type, case_sensitive=case_sensitive, - hidden=True) + cmd = alias.command.lower() if not alias.case_sensitive else alias.command + if alias.case_sensitive: + CASE_SENSITIVE_COMMANDS.add(alias.command) + else: + CASE_INSENSITIVE_COMMANDS.add(alias.command.lower()) + COMMANDS[cmd] = SpecialCommand( + handler, + command, + usage, + description, + arg_type=arg_type, + case_sensitive=alias.case_sensitive, + hidden=True, + aliases=None, + ) + -@export -def execute(cur, sql): +def execute(cur: Cursor, sql: str) -> list[SQLResult]: """Execute a special command and return the results. If the special command - is not supported a KeyError will be raised. + is not supported a CommandNotFound will be raised. """ - command, verbose, arg = parse_special_command(sql) + command, command_verbosity, arg = parse_special_command(sql) - if (command not in COMMANDS) and (command.lower() not in COMMANDS): - raise CommandNotFound + if (command not in CASE_SENSITIVE_COMMANDS) and (command.lower() not in CASE_INSENSITIVE_COMMANDS): + raise CommandNotFound(f'Command not found: {command}') try: special_cmd = COMMANDS[command] - except KeyError: + except KeyError as exc: special_cmd = COMMANDS[command.lower()] if special_cmd.case_sensitive: - raise CommandNotFound('Command not found: %s' % command) + raise CommandNotFound(f'Command not found: {command}') from exc # "help is a special case. We want built-in help, not # mycli help here. - if command == 'help' and arg: + if command.lower() == "help" and arg: return show_keyword_help(cur=cur, arg=arg) - if special_cmd.arg_type == NO_QUERY: + if special_cmd.arg_type == ArgType.NO_QUERY: return special_cmd.handler() - elif special_cmd.arg_type == PARSED_QUERY: - return special_cmd.handler(cur=cur, arg=arg, verbose=verbose) - elif special_cmd.arg_type == RAW_QUERY: + elif special_cmd.arg_type == ArgType.PARSED_QUERY: + return special_cmd.handler(cur=cur, arg=arg, command_verbosity=(command_verbosity == CommandVerbosity.VERBOSE)) + elif special_cmd.arg_type == ArgType.RAW_QUERY: return special_cmd.handler(cur=cur, query=sql) -@special_command('help', '\\?', 'Show this help.', arg_type=NO_QUERY, aliases=('\\?', '?')) -def show_help(): # All the parameters are ignored. - headers = ['Command', 'Shortcut', 'Description'] + raise CommandNotFound(f"Command type not found: {command}") + + +@special_command( + "help", + "help [term]", + "Show this table, or search for help on a term.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\?", case_sensitive=False), SpecialCommandAlias("?", case_sensitive=False)], +) +def show_help(*_args) -> list[SQLResult]: + header = ["Command", "Shortcut", "Usage", "Description"] result = [] - for _, value in sorted(COMMANDS.items()): - if not value.hidden: - result.append((value.command, value.shortcut, value.description)) - return [(None, result, headers, None)] + for _, value in sorted(COMMANDS.items(), key=lambda x: str.casefold(x[0])): + if value.hidden: + continue + if value.aliases: + shortcut = value.aliases[0].command + else: + shortcut = None + result.append((value.command, shortcut, value.usage, value.description)) + return [SQLResult(header=header, rows=result, postamble=f'Docs index — {DOCS_URL}')] + + +def _show_special_help(keyword: str) -> list[SQLResult]: + header = ['name', 'description', 'example'] + command = COMMANDS[keyword] + description = '\n'.join([command.usage or '', command.description]) + rows = [(keyword, description, '')] + return [SQLResult(header=header, rows=rows)] -def show_keyword_help(cur, arg): + +def _show_mysql_help(cur: Cursor, keyword: str) -> list[SQLResult]: """ - Call the built-in "show ", to display help for an SQL keyword. + Call the built-in "show ", to display help for an SQL keyword. :param cur: cursor :param arg: string :return: list """ - keyword = arg.strip('"').strip("'") - query = "help '{0}'".format(keyword) - log.debug(query) - cur.execute(query) + query = 'help %s' + logger.debug(query) + cur.execute(query, keyword) + if cur.description and cur.rowcount > 0: + header = [x[0] for x in cur.description] + return [SQLResult(header=header, rows=cur)] + logger.debug(query) + cur.execute(query, (f'%{keyword}%',)) if cur.description and cur.rowcount > 0: - headers = [x[0] for x in cur.description] - return [(None, cur, headers, '')] + header = [x[0] for x in cur.description] + return [SQLResult(preamble='Similar terms:', header=header, rows=cur)] else: - return [(None, None, None, 'No help found for {0}.'.format(keyword))] + return [SQLResult(status=f'No help found for "{keyword}".')] + + +def show_keyword_help(cur: Cursor, arg: str) -> list[SQLResult]: + keyword = arg.strip().strip('"').strip("'").rstrip('+-') + if keyword in CASE_SENSITIVE_COMMANDS: + return _show_special_help(keyword) + elif keyword.lower() in CASE_INSENSITIVE_COMMANDS: + return _show_special_help(keyword.lower()) -@special_command('exit', '\\q', 'Exit.', arg_type=NO_QUERY, aliases=('\\q', )) -@special_command('quit', '\\q', 'Quit.', arg_type=NO_QUERY) -def quit(*_args): + return _show_mysql_help(cur, keyword) + + +@special_command('\\bug', '\\bug', 'File a bug on GitHub.', arg_type=ArgType.NO_QUERY) +def file_bug(*_args) -> list[SQLResult]: + webbrowser.open_new_tab(ISSUES_URL) + return [SQLResult(status=f'{ISSUES_URL} — press "New Issue"')] + + +@special_command( + "exit", + "exit", + "Exit.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\q", case_sensitive=False)], +) +@special_command( + "quit", + "quit", + "Quit.", + arg_type=ArgType.NO_QUERY, + aliases=[SpecialCommandAlias("\\q", case_sensitive=False)], +) +def quit_(*_args): raise EOFError -@special_command('\\e', '\\e', 'Edit command with editor (uses $EDITOR).', - arg_type=NO_QUERY, case_sensitive=True) -@special_command('\\clip', '\\clip', 'Copy query to the system clipboard.', - arg_type=NO_QUERY, case_sensitive=True) -@special_command('\\G', '\\G', 'Display current query results vertically.', - arg_type=NO_QUERY, case_sensitive=True) +@special_command( + "\\edit", + "\\edit | \\edit ", + "Edit query with editor (uses $VISUAL or $EDITOR).", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\e", case_sensitive=True)], +) +@special_command( + "\\clip", + "\\clip", + "Copy query to the system clipboard.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, +) +@special_command( + "\\G", + "\\G", + "Display query results vertically.", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, +) +@special_command( + "\\g", + "\\g", + "Display query results (mnemonic: go).", + arg_type=ArgType.NO_QUERY, + case_sensitive=True, +) def stub(): raise NotImplementedError + + +if LLM_IMPORTED: + + @special_command( + "\\llm", + "\\llm [arguments]", + "Interrogate an LLM. See \"\\llm help\".", + arg_type=ArgType.RAW_QUERY, + case_sensitive=True, + aliases=[SpecialCommandAlias("\\ai", case_sensitive=True)], + ) + def llm_stub(): + raise NotImplementedError diff --git a/mycli/packages/special/utils.py b/mycli/packages/special/utils.py index ef96093a..fc014323 100644 --- a/mycli/packages/special/utils.py +++ b/mycli/packages/special/utils.py @@ -1,21 +1,33 @@ +import datetime +import logging import os -import subprocess +from typing import Any -def handle_cd_command(arg): +import click +import pymysql +from pymysql.cursors import Cursor + +logger = logging.getLogger(__name__) + +CACHED_SSL_VERSION: dict[tuple, str | None] = {} + + +def handle_cd_command(command: list[str]) -> tuple[bool, str | None]: """Handles a `cd` shell command by calling python's os.chdir.""" - CD_CMD = 'cd' - tokens = arg.split(CD_CMD + ' ') - directory = tokens[-1] if len(tokens) > 1 else None - if not directory: - return False, "No folder name was provided." + if not command[0].lower() == 'cd': + return False, 'Not a cd command.' + if len(command) != 2: + return False, 'Exactly one directory name must be provided.' + directory = command[1] try: os.chdir(directory) - subprocess.call(['pwd']) + click.echo(os.getcwd(), err=True) return True, None except OSError as e: return False, e.strerror -def format_uptime(uptime_in_seconds): + +def format_uptime(uptime_in_seconds: str) -> str: """Format number of seconds into human-readable string. :param uptime_in_seconds: The server uptime in seconds. @@ -30,17 +42,104 @@ def format_uptime(uptime_in_seconds): h, m = divmod(m, 60) d, h = divmod(h, 24) - uptime_values = [] + uptime_values: list[str] = [] - for value, unit in ((d, 'days'), (h, 'hours'), (m, 'min'), (s, 'sec')): + for value, unit in ((d, "days"), (h, "hours"), (m, "min"), (s, "sec")): if value == 0 and not uptime_values: # Don't include a value/unit if the unit isn't applicable to # the uptime. E.g. don't do 0 days 0 hours 1 min 30 sec. continue - elif value == 1 and unit.endswith('s'): + if value == 1 and unit.endswith("s"): # Remove the "s" if the unit is singular. unit = unit[:-1] - uptime_values.append('{0} {1}'.format(value, unit)) + uptime_values.append(f'{value} {unit}') - uptime = ' '.join(uptime_values) + uptime = " ".join(uptime_values) return uptime + + +def get_uptime(cur: Cursor) -> int: + query = 'SHOW STATUS LIKE "Uptime"' + logger.debug(query) + + uptime = 0 + + try: + cur.execute(query) + if one := cur.fetchone(): + uptime = int(one[1] or 0) + except pymysql.err.OperationalError: + pass + + return uptime + + +def get_warning_count(cur: Cursor) -> int: + query = 'SHOW COUNT(*) WARNINGS' + logger.debug(query) + + warning_count = 0 + + try: + cur.execute(query) + if one := cur.fetchone(): + warning_count = int(one[0] or 0) + except pymysql.err.OperationalError: + pass + + return warning_count + + +def get_ssl_version(cur: Cursor) -> str | None: + cache_key = (id(cur.connection), cur.connection.thread_id()) + + if cache_key in CACHED_SSL_VERSION: + return CACHED_SSL_VERSION[cache_key] or None + + query = 'SHOW STATUS LIKE "Ssl_version"' + logger.debug(query) + + ssl_version = None + + try: + cur.execute(query) + if one := cur.fetchone(): + CACHED_SSL_VERSION[cache_key] = one[1] + ssl_version = one[1] or None + else: + CACHED_SSL_VERSION[cache_key] = '' + except pymysql.err.OperationalError: + pass + + return ssl_version + + +def get_ssl_cipher(cur: Cursor) -> str | None: + query = 'SHOW STATUS LIKE "Ssl_cipher"' + logger.debug(query) + + ssl_cipher = None + + try: + cur.execute(query) + if one := cur.fetchone(): + ssl_cipher = one[1] or None + except pymysql.err.OperationalError: + pass + + return ssl_cipher + + +def get_server_timezone(variables: dict[str, Any]) -> str: + try: + if variables['time_zone'] == 'SYSTEM': + server_tz = variables['system_time_zone'] + else: + server_tz = variables['time_zone'] + return server_tz + except KeyError: + return '' + + +def get_local_timezone() -> str: + return datetime.datetime.now().astimezone().tzname() or '' diff --git a/mycli/packages/sql_utils.py b/mycli/packages/sql_utils.py new file mode 100644 index 00000000..c03d5c85 --- /dev/null +++ b/mycli/packages/sql_utils.py @@ -0,0 +1,558 @@ +from __future__ import annotations + +import re +from typing import Any, Generator, Literal + +import sqlglot +import sqlglot.tokens +import sqlparse +from sqlparse.sql import Function, Identifier, IdentifierList, Token, TokenList +from sqlparse.tokens import DML, Keyword, Punctuation + +sqlparse.engine.grouping.MAX_GROUPING_DEPTH = None # type: ignore[assignment] +sqlparse.engine.grouping.MAX_GROUPING_TOKENS = None # type: ignore[assignment] + +cleanup_regex: dict[str, re.Pattern] = { + # This matches only alphanumerics and underscores. + "alphanum_underscore": re.compile(r"(\w+)$"), + # This matches everything except spaces, parens, colon, and comma + "many_punctuations": re.compile(r"([^():,\s]+)$"), + # This matches everything except spaces, parens, colon, comma, and period + "most_punctuations": re.compile(r"([^\.():,\s]+)$"), + # This matches everything except a space. + "all_punctuations": re.compile(r"([^\s]+)$"), +} + + +def last_word( + text: str, + include: Literal[ + 'alphanum_underscore', + 'many_punctuations', + 'most_punctuations', + 'all_punctuations', + ] = 'alphanum_underscore', +) -> str: + r""" + Find the last word in a sentence. + + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + """ + + if not text: # Empty string + return "" + + if text[-1].isspace(): + return "" + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return "" + + +# This code is borrowed from sqlparse example script. +# +def is_subselect(parsed: TokenList) -> bool: + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ("SELECT", "INSERT", "UPDATE", "CREATE", "DELETE"): + return True + return False + + +def get_last_select(parsed: TokenList) -> TokenList: + """ + Takes a parsed sql statement and returns the last select query where applicable. + + The intended use case is for when giving table suggestions based on columns, where + we only want to look at the columns from the most recent select. This works for a single + select query, or one or more sub queries (the useful part). + + The custom logic is necessary because the typical sqlparse logic for things like finding + sub selects (i.e. is_subselect) only works on complete statements, such as: + + * select c1 from t1; + + However when suggesting tables based on columns, we only have partial select statements, i.e.: + + * select c1 + * select c1 from (select c2) + + So given the above, we must parse them ourselves as they are not viewed as complete statements. + + Returns a TokenList of the last select statement's tokens. + """ + select_indexes: list[int] = [] + + for token in parsed: + if token.match(DML, "select"): # match is case insensitive + select_indexes.append(parsed.token_index(token)) + + last_select = TokenList() + + if select_indexes: + last_select = TokenList(parsed[select_indexes[-1] :]) + + return last_select + + +def extract_from_part(parsed: TokenList, stop_at_punctuation: bool = True) -> Generator[Any, None, None]: + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + yield from extract_from_part(item, stop_at_punctuation) + elif stop_at_punctuation and item.ttype is Punctuation: + return None + # Multiple JOINs in the same query won't work properly since + # "ON" is a keyword and will trigger the next elif condition. + # So instead of stooping the loop when finding an "ON" skip it + # eg: 'SELECT * FROM abc JOIN def ON abc.id = def.abc_id JOIN ghi' + elif item.ttype is Keyword and item.value.upper() == "ON": + tbl_prefix_seen = False + continue + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # StopIteration. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif item.ttype is Keyword and (not item.value.upper() == "FROM") and (not item.value.upper().endswith("JOIN")): + return None + else: + yield item + elif (item.ttype is Keyword or item.ttype is Keyword.DML) and item.value.upper() in ( + "COPY", + "FROM", + "INTO", + "UPDATE", + "TABLE", + "JOIN", + ): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if identifier.ttype is Keyword and identifier.value.upper() == "FROM": + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream: Generator[Any, None, None]) -> Generator[tuple[str | None, str, str], None, None]: + """yields tuples of (schema_name, table_name, table_alias)""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. + try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + except AttributeError: + continue + if real_name: + yield (schema_name, real_name, identifier.get_alias()) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + + if real_name: + yield (schema_name, real_name, item.get_alias()) + else: + name = item.get_name() + yield (None, name, item.get_alias() or name) + elif isinstance(item, Function): + yield (None, item.get_name(), item.get_name()) + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql: str) -> list[tuple[str | None, str, str]]: + """Extract the table names from an SQL statement. + + Returns a list of (schema, table, alias) tuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == "insert" + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + return list(extract_table_identifiers(stream)) + + +def extract_columns_from_select(sql: str) -> list[str]: + """ + Extract the column names from a select SQL statement. + + Returns a list of columns. + """ + parsed = sqlparse.parse(sql) + if not parsed: + return [] + + statement = get_last_select(parsed[0]) + + # if there is no select, skip checking for columns + if not statement: + return [] + + columns = [] + + # Loops through the tokens (pieces) of the SQL statement. + # Once it finds the SELECT token (generally first), it + # will then start looking for columns from that point on. + # The get_real_name() function returns the real column name + # even if an alias is used. + found_select = False + for token in statement.tokens: + if token.ttype is DML and token.value.upper() == 'SELECT': + found_select = True + elif found_select: + if isinstance(token, IdentifierList): + # multiple columns + for identifier in token.get_identifiers(): + if isinstance(identifier, Identifier): + column = identifier.get_real_name() + elif isinstance(identifier, Token): + column = identifier.value + else: + continue + columns.append(column) + elif isinstance(token, Identifier): + # single column + column = token.get_real_name() + columns.append(column) + elif token.ttype is Keyword: + break + + if columns: + break + return columns + + +def extract_tables_from_complete_statements(sql: str) -> list[tuple[str | None, str, str | None]]: + """Extract the table names from a complete and valid series of SQL + statements. + + Returns a list of (schema, table, alias) tuples + + """ + # sqlglot chokes entirely on things like "\T" that it doesn't know about, + # but is much better at extracting table names from complete statements. + # sqlparse can extract the series of statements, though it also doesn't + # understand "\T". + roughly_parsed = sqlparse.parse(sql) + if not roughly_parsed: + return [] + + finely_parsed = [] + for rough_statement in roughly_parsed: + try: + finely_parsed.append(sqlglot.parse_one(str(rough_statement), read='mysql')) + except sqlglot.errors.ParseError: + pass + + tables = [] + for fine_statement in finely_parsed: + for identifier in fine_statement.find_all(sqlglot.exp.Table): + if identifier.parent_select and identifier.parent_select.sql().startswith('WITH'): + continue + tables.append(( + None if identifier.db == '' else identifier.db, + identifier.name, + None if identifier.alias == '' else identifier.alias, + )) + + return tables + + +def find_prev_keyword(sql: str) -> tuple[Token | None, str]: + """Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, "" + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ("AND", "OR", "NOT", "BETWEEN") + + for t in reversed(flattened): + if t.value == "(" or (t.is_keyword and (t.value.upper() not in logical_operators)): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index thows an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = "".join(tok.value for tok in flattened[: idx + 1]) + return t, text + + return None, "" + + +def query_starts_with(query: str, prefixes: list[str]) -> bool: + """Check if the query starts with any item from *prefixes*.""" + prefixes = [prefix.lower() for prefix in prefixes] + formatted_sql = sqlparse.format(query.lower(), strip_comments=True) + return bool(formatted_sql) and formatted_sql.split()[0] in prefixes + + +def queries_start_with(queries: str, prefixes: list[str]) -> bool: + """Check if any queries start with any item from *prefixes*.""" + for query in sqlparse.split(queries): + if query and query_starts_with(query, prefixes) is True: + return True + return False + + +def query_has_where_clause(query: str) -> bool: + """Check if the query contains a where-clause.""" + return any(isinstance(token, sqlparse.sql.Where) for token_list in sqlparse.parse(query) for token in token_list) + + +# todo: handle "UPDATE LOW_PRIORITY" and "UPDATE IGNORE" +def query_is_single_table_update(query: str) -> bool: + """Check if a query is a simple single-table UPDATE.""" + cleaned_query_for_parsing_only = sqlparse.format(query, strip_comments=True) + cleaned_query_for_parsing_only = re.sub(r'\s+', ' ', cleaned_query_for_parsing_only) + if not cleaned_query_for_parsing_only: + return False + parsed = sqlparse.parse(cleaned_query_for_parsing_only) + if not parsed: + return False + statement = parsed[0] + try: + retval = bool( + statement[0].value.lower() == 'update' + and statement[1].is_whitespace + and ',' not in statement[2].value # multiple tables + and statement[3].is_whitespace + and statement[4].value.lower() == 'set' + ) + except IndexError: + retval = False + + return retval + + +def is_destructive(keywords: list[str], queries: str) -> bool: + """Returns True if any of the queries in *queries* is destructive.""" + for query in sqlparse.split(queries): + if not query: + continue + # subtle: if "UPDATE" is one of our keywords AND "query" starts with "UPDATE" + if query_starts_with(query, keywords) and query_starts_with(query, ["update"]): + if query_has_where_clause(query) and query_is_single_table_update(query): + return False + else: + return True + if query_starts_with(query, keywords): + return True + + return False + + +def is_dropping_database(queries: str, dbname: str | None) -> bool: + """Determine if the query is dropping a specific database.""" + result = False + if dbname is None: + return False + + def normalize_db_name(db: str) -> str: + return db.lower().strip('`"') + + dbname = normalize_db_name(dbname) + + for query in sqlparse.parse(queries): + keywords = [t for t in query.tokens if t.is_keyword] + if len(keywords) < 2: + continue + if keywords[0].normalized in ("DROP", "CREATE") and keywords[1].value.lower() in ( + "database", + "schema", + ): + database_token = next((t for t in query.tokens if isinstance(t, Identifier)), None) + if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: + result = keywords[0].normalized == "DROP" + return result + + +def need_completion_refresh(queries: str) -> bool: + """Determines if the completion needs a refresh by checking if the sql + statement is an alter, create, drop or change db.""" + for query in sqlparse.split(queries): + try: + first_token = query.split()[0] + if first_token.lower() in ("alter", "create", "use", "\\r", "\\u", "connect", "drop", "rename"): + return True + except Exception: + continue + return False + + +def need_completion_reset(queries: str) -> bool: + """Determines if the statement is a database switch such as 'use' or '\\u'. + When a database is changed the existing completions must be reset before we + start the completion refresh for the new database. + """ + for query in sqlparse.split(queries): + try: + tokens = query.split() + first_token = tokens[0] + if first_token.lower() in ("use", "\\u"): + return True + if first_token.lower() in ("\\r", "connect") and len(tokens) > 1: + return True + except Exception: + continue + return False + + +def is_mutating(status_plain: str | None) -> bool: + """Determines if the statement is mutating based on the status.""" + if not status_plain: + return False + + mutating = {"insert", "update", "delete", "alter", "create", "drop", "replace", "truncate", "load", "rename"} + return status_plain.split(None, 1)[0].lower() in mutating + + +def is_select(status_plain: str | None) -> bool: + """Returns true if the first word in status is 'select'.""" + if not status_plain: + return False + return status_plain.split(None, 1)[0].lower() == "select" + + +def classify_sandbox_statement(text: str) -> tuple[str | None, str | None]: + """Classify a SQL statement for sandbox mode and extract the new password. + + Returns (statement_type, new_password) where statement_type is one of: + - 'alter_user' — ALTER USER ... IDENTIFIED BY ... + - 'set_password' — SET PASSWORD [FOR ...] = ... + - 'quit' — quit, exit, \\q + - None — not allowed in sandbox mode + """ + stripped = text.strip() + if not stripped: + return ('quit', None) + + try: + tokens = list(sqlglot.tokenize(stripped, dialect='mysql')) + except sqlglot.errors.TokenError: + tokens = [] + + if not tokens: + return ('quit', None) + + types = [t.token_type for t in tokens] + texts = [t.text.upper() for t in tokens] + tt = sqlglot.tokens.TokenType + + # quit, exit + if len(tokens) == 1 and types[0] == tt.VAR and texts[0] in ('QUIT', 'EXIT'): + return ('quit', None) + + # \q + if len(tokens) == 2 and types[0] == tt.BACKSLASH and texts[1] == 'Q': + return ('quit', None) + + # ALTER USER ... + if len(tokens) >= 2 and types[0] == tt.ALTER and texts[1] == 'USER': + pw = _find_password_after_by(tokens) + return ('alter_user', pw) + + # SET PASSWORD ... + if len(tokens) >= 2 and types[0] == tt.SET and texts[1] == 'PASSWORD': + pw = _find_password_after_eq(tokens) + return ('set_password', pw) + + return (None, None) + + +def _find_password_after_by(tokens: list[sqlglot.tokens.Token]) -> str | None: + """Find a password literal following a BY token (for ALTER USER ... IDENTIFIED BY 'pw').""" + tt = sqlglot.tokens.TokenType + for i, tok in enumerate(tokens): + if tok.token_type == tt.VAR and tok.text.upper() == 'BY' and i + 1 < len(tokens): + next_tok = tokens[i + 1] + if next_tok.token_type == tt.STRING: + return next_tok.text + return None + + +def _find_password_after_eq(tokens: list[sqlglot.tokens.Token]) -> str | None: + """Find a password literal following an = token (for SET PASSWORD = 'pw').""" + tt = sqlglot.tokens.TokenType + for i, tok in enumerate(tokens): + if tok.token_type == tt.EQ and i + 1 < len(tokens): + next_tok = tokens[i + 1] + if next_tok.token_type == tt.STRING: + return next_tok.text + return None + + +def is_sandbox_allowed(text: str) -> bool: + """Return True if the command is allowed in expired-password sandbox mode.""" + stmt_type, _ = classify_sandbox_statement(text) + return stmt_type is not None + + +def is_password_change(text: str) -> bool: + """Return True if the command is a password change statement.""" + stmt_type, _ = classify_sandbox_statement(text) + return stmt_type in ('alter_user', 'set_password') + + +def extract_new_password(text: str) -> str | None: + """Extract the new password from an ALTER USER or SET PASSWORD statement.""" + _, password = classify_sandbox_statement(text) + return password diff --git a/mycli/packages/sqlresult.py b/mycli/packages/sqlresult.py new file mode 100644 index 00000000..b1f5e272 --- /dev/null +++ b/mycli/packages/sqlresult.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from functools import cached_property + +from prompt_toolkit.formatted_text import FormattedText, to_plain_text +from pymysql.cursors import Cursor + + +@dataclass +class SQLResult: + preamble: str | None = None + header: list[str] | str | None = None + rows: Cursor | list[tuple] | None = None + postamble: str | None = None + status: str | FormattedText | None = None + command: dict[str, str | float] | None = None + + def __str__(self): + return f"{self.preamble}, {self.header}, {self.rows}, {self.postamble}, {self.status}, {self.command}" + + @cached_property + def status_plain(self): + if self.status is None: + return None + return to_plain_text(self.status) diff --git a/mycli/packages/ssh_utils.py b/mycli/packages/ssh_utils.py new file mode 100644 index 00000000..1b81384a --- /dev/null +++ b/mycli/packages/ssh_utils.py @@ -0,0 +1,27 @@ +import sys + +import click + +try: + import paramiko +except ImportError: + from mycli.packages.paramiko_stub import paramiko # type: ignore[no-redef] + + +# it isn't cool that this utility function can exit(), but it is slated to be removed anyway +def read_ssh_config(ssh_config_path: str): + ssh_config = paramiko.config.SSHConfig() + try: + with open(ssh_config_path) as f: + ssh_config.parse(f) + except FileNotFoundError as e: + click.secho(str(e), err=True, fg="red") + sys.exit(1) + # Paramiko prior to version 2.7 raises Exception on parse errors. + # In 2.7 it has become paramiko.ssh_exception.SSHException, + # but let's catch everything for compatibility + except Exception as err: + click.secho(f"Could not parse SSH configuration file {ssh_config_path}:\n{err} ", err=True, fg="red") + sys.exit(1) + else: + return ssh_config diff --git a/mycli/packages/string_utils.py b/mycli/packages/string_utils.py new file mode 100644 index 00000000..56103330 --- /dev/null +++ b/mycli/packages/string_utils.py @@ -0,0 +1,15 @@ +import re + +from cli_helpers.utils import strip_ansi +from prompt_toolkit.formatted_text import ( + FormattedText, + to_plain_text, +) + + +def sanitize_terminal_title(title: FormattedText) -> str: + sanitized = to_plain_text(title) + sanitized = strip_ansi(sanitized) + sanitized = sanitized.replace('\n', ' ') + sanitized = re.sub('[\x00-\x1f\x7f]', '', sanitized) + return sanitized diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index e6587bd3..31def8e1 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -1,62 +1,71 @@ """Format adapter for sql.""" -from mycli.packages.parseutils import extract_tables +from __future__ import annotations -supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', - 'sql-update-2', ) +from typing import Generator, Union + +from cli_helpers.tabular_output import TabularOutputFormatter + +from mycli.packages.sql_utils import extract_tables_from_complete_statements + +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) preprocessors = () +formatter: TabularOutputFormatter + -def escape_for_sql_statement(value): +def escape_for_sql_statement(value: Union[bytes, str]) -> str: if isinstance(value, bytes): - return f"X'{value.hex()}'" + return f"0x{value.hex()}" else: return formatter.mycli.sqlexecute.conn.escape(value) -def adapter(data, headers, table_format=None, **kwargs): - tables = extract_tables(formatter.query) +def adapter(data: list[str], headers: list[str], table_format: Union[str, None] = None, **kwargs) -> Generator[str, None, None]: + tables = extract_tables_from_complete_statements(formatter.query) if len(tables) > 0: table = tables[0] if table[0]: - table_name = "{}.{}".format(*table[:2]) + table_name = f'{table[0]}.{table[1]}' else: table_name = table[1] else: table_name = "`DUAL`" - if table_format == 'sql-insert': + if table_format == "sql-insert": h = "`, `".join(headers) - yield "INSERT INTO {} (`{}`) VALUES".format(table_name, h) + yield f'INSERT INTO {table_name} (`{h}`) VALUES' prefix = " " for d in data: - values = ", ".join(escape_for_sql_statement(v) - for i, v in enumerate(d)) - yield "{}({})".format(prefix, values) + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) + yield f'{prefix}({values})' if prefix == " ": prefix = ", " yield ";" - if table_format.startswith('sql-update'): - s = table_format.split('-') + if table_format and table_format.startswith("sql-update"): + s = table_format.split("-") keys = 1 if len(s) > 2: keys = int(s[-1]) for d in data: - yield "UPDATE {} SET".format(table_name) + yield f'UPDATE {table_name} SET' prefix = " " for i, v in enumerate(d[keys:], keys): - yield "{}`{}` = {}".format(prefix, headers[i], escape_for_sql_statement(v)) + yield f'{prefix}`{headers[i]}` = {escape_for_sql_statement(v)}' if prefix == " ": prefix = ", " f = "`{}` = {}" - where = (f.format(headers[i], escape_for_sql_statement( - d[i])) for i in range(keys)) - yield "WHERE {};".format(" AND ".join(where)) + where = (f.format(headers[i], escape_for_sql_statement(d[i])) for i in range(keys)) + yield f'WHERE {" AND ".join(where)};' -def register_new_formatter(TabularOutputFormatter): +def register_new_formatter(tof: TabularOutputFormatter): global formatter - formatter = TabularOutputFormatter + formatter = tof for sql_format in supported_formats: - TabularOutputFormatter.register_new_formatter( - sql_format, adapter, preprocessors, {'table_format': sql_format}) + tof.register_new_formatter(sql_format, adapter, preprocessors, {"table_format": sql_format}) diff --git a/mycli/packages/toolkit/fzf.py b/mycli/packages/toolkit/fzf.py deleted file mode 100644 index 36cb347a..00000000 --- a/mycli/packages/toolkit/fzf.py +++ /dev/null @@ -1,45 +0,0 @@ -from shutil import which - -from pyfzf import FzfPrompt -from prompt_toolkit import search -from prompt_toolkit.key_binding.key_processor import KeyPressEvent - -from .history import FileHistoryWithTimestamp - - -class Fzf(FzfPrompt): - def __init__(self): - self.executable = which("fzf") - if self.executable: - super().__init__() - - def is_available(self) -> bool: - return self.executable is not None - - -def search_history(event: KeyPressEvent): - buffer = event.current_buffer - history = buffer.history - - fzf = Fzf() - - if fzf.is_available() and isinstance(history, FileHistoryWithTimestamp): - history_items_with_timestamp = history.load_history_with_timestamp() - - formatted_history_items = [] - original_history_items = [] - for item, timestamp in history_items_with_timestamp: - formatted_item = item.replace('\n', ' ') - timestamp = timestamp.split(".")[0] if "." in timestamp else timestamp - formatted_history_items.append(f"{timestamp} {formatted_item}") - original_history_items.append(item) - - result = fzf.prompt(formatted_history_items, fzf_options="--tiebreak=index") - - if result: - selected_index = formatted_history_items.index(result[0]) - buffer.text = original_history_items[selected_index] - buffer.cursor_position = len(buffer.text) - else: - # Fallback to default reverse incremental search - search.start_search(direction=search.SearchDirection.BACKWARD) diff --git a/mycli/schema_prefetcher.py b/mycli/schema_prefetcher.py new file mode 100644 index 00000000..25467598 --- /dev/null +++ b/mycli/schema_prefetcher.py @@ -0,0 +1,241 @@ +"""Background prefetcher for multi-schema auto-completion. + +The default completion refresher only populates metadata for the +currently-selected schema. ``SchemaPrefetcher`` loads metadata for +additional schemas on a background thread so that users can get +qualified auto-completion suggestions (``OtherSchema.table``) without +switching databases first. +""" + +from __future__ import annotations + +from enum import Enum +import logging +import threading +from typing import TYPE_CHECKING, Any, Iterable + +from mycli.sqlexecute import SQLExecute + +if TYPE_CHECKING: # pragma: no cover - typing only + from mycli.main import MyCli + from mycli.sqlcompleter import SQLCompleter + +_logger = logging.getLogger(__name__) + + +class PrefetchMode(str, Enum): + ALWAYS = 'always' + NEVER = 'never' + LISTED = 'listed' + + +def parse_prefetch_config(mode: str, schema_list: list[str]) -> list[str] | None: + """Parse the ``prefetch_schemas_mode`` / ``prefetch_schemas_list`` options. + + Returns ``None`` when every accessible schema should be prefetched + (``always``), an empty list when prefetching is disabled + (``never``), or ``schema_list`` when the mode is ``listed``. + Unknown modes fall back to ``always``. + """ + try: + parsed = PrefetchMode(mode.strip().lower()) + except ValueError: + return None + if parsed is PrefetchMode.NEVER: + return [] + if parsed is PrefetchMode.LISTED: + return schema_list + return None + + +class SchemaPrefetcher: + """Run schema prefetch work on a dedicated background thread.""" + + def __init__(self, mycli: 'MyCli') -> None: + self.mycli = mycli + self._thread: threading.Thread | None = None + self._cancel = threading.Event() + self._loaded: set[str] = set() + + def is_prefetching(self) -> bool: + return bool(self._thread and self._thread.is_alive()) + + def clear_loaded(self) -> None: + """Forget which schemas have been prefetched (used on reset).""" + self._loaded.clear() + + def stop(self, timeout: float = 2.0) -> None: + """Signal the background thread to stop and wait briefly for it.""" + if self._thread and self._thread.is_alive(): + self._cancel.set() + self._thread.join(timeout=timeout) + self._cancel = threading.Event() + self._thread = None + + def start_configured(self) -> None: + """Start prefetching based on the user's prefetch settings.""" + mode = getattr(self.mycli, 'prefetch_schemas_mode', PrefetchMode.ALWAYS.value) + schema_list = getattr(self.mycli, 'prefetch_schemas_list', []) + parsed = parse_prefetch_config(mode, schema_list) + if parsed is not None and not parsed: + # ``never`` or ``listed`` with an empty list — nothing to do. + return + self._start(parsed) + + def prefetch_schema_now(self, schema: str) -> None: + """Fetch *schema* immediately on a background thread. + + Used when a user manually switches to a schema. The method + returns quickly; the actual work happens in the new thread. + """ + if not schema: + return + # Avoid double-fetching while a full-prefetch pass is running. + self.stop() + self._start([schema]) + + def _start(self, schemas: Iterable[str] | None) -> None: + """Spawn the background worker. + + ``schemas=None`` defers resolution to the worker, which lists + every database via its own dedicated connection — the main + thread's ``sqlexecute`` must not be used here since the worker + would race with the REPL. + """ + self.stop() + queue: list[str] | None = None if schemas is None else list(schemas) + self._cancel = threading.Event() + self._thread = threading.Thread( + target=self._run, + args=(queue,), + name='schema_prefetcher', + daemon=True, + ) + self._thread.start() + self._invalidate_app() + + def _run(self, schemas: list[str] | None) -> None: + executor: SQLExecute | None = None + try: + executor = self._make_executor() + except Exception as e: # pragma: no cover - defensive + _logger.error('schema prefetch could not open connection: %r', e) + self._invalidate_app() + return + try: + if schemas is None: + try: + schemas = list(executor.databases()) + except Exception as e: + _logger.error('failed to list databases for prefetch: %r', e) + return + current = self._current_schema() + existing = set(self.mycli.completer.dbmetadata.get('tables', {}).keys()) + queue = [s for s in schemas if s and s != current and s not in self._loaded and s not in existing] + for schema in queue: + if self._cancel.is_set(): + return + try: + self._prefetch_one(executor, schema) + self._loaded.add(schema) + except Exception as e: + _logger.error('prefetch failed for schema %r: %r', schema, e) + finally: + try: + executor.close() + except Exception: # pragma: no cover - defensive + pass + self._invalidate_app() + + def _prefetch_one(self, executor: SQLExecute, schema: str) -> None: + _logger.debug('prefetching schema %r', schema) + table_rows = list(executor.table_columns(schema=schema)) + fk_rows = list(executor.foreign_keys(schema=schema)) + enum_rows = list(executor.enum_values(schema=schema)) + func_rows = list(executor.functions(schema=schema)) + proc_rows = list(executor.procedures(schema=schema)) + + # Use the live completer's escape logic so keys match what the + # completion engine computes when parsing user input. + completer = self.mycli.completer + table_columns: dict[str, list[str]] = {} + for table, column in table_rows: + esc_table = completer.escape_name(table) + esc_col = completer.escape_name(column) + cols = table_columns.setdefault(esc_table, ['*']) + cols.append(esc_col) + + fk_tables: dict[str, set[str]] = {} + fk_relations: list[tuple[str, str, str, str]] = [] + for table, col, ref_table, ref_col in fk_rows: + esc_table = completer.escape_name(table) + esc_col = completer.escape_name(col) + esc_ref_table = completer.escape_name(ref_table) + esc_ref_col = completer.escape_name(ref_col) + fk_tables.setdefault(esc_table, set()).add(esc_ref_table) + fk_tables.setdefault(esc_ref_table, set()).add(esc_table) + fk_relations.append((esc_table, esc_col, esc_ref_table, esc_ref_col)) + fk_payload: dict[str, Any] = {'tables': fk_tables, 'relations': fk_relations} + + enum_values: dict[str, dict[str, list[str]]] = {} + for table, column, values in enum_rows: + esc_table = completer.escape_name(table) + esc_col = completer.escape_name(column) + enum_values.setdefault(esc_table, {})[esc_col] = list(values) + + functions: dict[str, None] = {} + for row in func_rows: + if not row or not row[0]: + continue + functions[completer.escape_name(row[0])] = None + + procedures: dict[str, None] = {} + for row in proc_rows: + if not row or not row[0]: + continue + procedures[completer.escape_name(row[0])] = None + + with self.mycli._completer_lock: + live_completer: 'SQLCompleter' = self.mycli.completer + live_completer.load_schema_metadata( + schema=schema, + table_columns=table_columns, + foreign_keys=fk_payload, + enum_values=enum_values, + functions=functions, + procedures=procedures, + ) + self._invalidate_app() + + def _current_schema(self) -> str | None: + sqlexecute = self.mycli.sqlexecute + return sqlexecute.dbname if sqlexecute is not None else None + + def _make_executor(self) -> SQLExecute: + sqlexecute = self.mycli.sqlexecute + assert sqlexecute is not None + return SQLExecute( + sqlexecute.dbname, + sqlexecute.user, + sqlexecute.password, + sqlexecute.host, + sqlexecute.port, + sqlexecute.socket, + sqlexecute.character_set, + sqlexecute.local_infile, + sqlexecute.ssl, + sqlexecute.ssh_user, + sqlexecute.ssh_host, + sqlexecute.ssh_port, + sqlexecute.ssh_password, + sqlexecute.ssh_key_filename, + ) + + def _invalidate_app(self) -> None: + prompt_session = getattr(self.mycli, 'prompt_session', None) + if prompt_session is None: + return + try: + prompt_session.app.invalidate() + except Exception: # pragma: no cover - defensive + pass diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 17363f48..67d8f8c0 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -1,270 +1,1000 @@ -import logging -from re import compile, escape +from __future__ import annotations + from collections import Counter +from enum import IntEnum +import logging +import re +from typing import Any, Collection, Generator, Iterable, Literal -from prompt_toolkit.completion import Completer, Completion +from prompt_toolkit.completion import CompleteEvent, Completer, Completion +from prompt_toolkit.completion.base import Document +from pygments.lexers._mysql_builtins import MYSQL_DATATYPES, MYSQL_FUNCTIONS, MYSQL_KEYWORDS +import rapidfuzz -from .packages.completion_engine import suggest_type -from .packages.parseutils import last_word -from .packages.filepaths import parse_path, complete_path, suggest_path -from .packages.special.favoritequeries import FavoriteQueries +from mycli.packages.completion_engine import is_inside_quotes, suggest_type +from mycli.packages.filepaths import complete_path, parse_path, suggest_path +from mycli.packages.special import llm +from mycli.packages.special.favoritequeries import FavoriteQueries +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.sql_utils import extract_columns_from_select, extract_tables, last_word _logger = logging.getLogger(__name__) +_CASE_CHANGE_PAT = re.compile('(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])') + + +class Fuzziness(IntEnum): + PERFECT = 0 + REGEX = 1 + UNDER_WORDS = 2 + CAMEL_CASE = 3 + RAPIDFUZZ = 4 class SQLCompleter(Completer): - keywords = [ - 'SELECT', 'FROM', 'WHERE', 'UPDATE', 'DELETE FROM', 'GROUP BY', - 'JOIN', 'INSERT INTO', 'LIKE', 'LIMIT', 'ACCESS', 'ADD', 'ALL', - 'ALTER TABLE', 'AND', 'ANY', 'AS', 'ASC', 'AUTO_INCREMENT', - 'BEFORE', 'BEGIN', 'BETWEEN', 'BIGINT', 'BINARY', 'BY', 'CASE', - 'CHANGE MASTER TO', 'CHAR', 'CHARACTER SET', 'CHECK', 'COLLATE', - 'COLUMN', 'COMMENT', 'COMMIT', 'CONSTRAINT', 'CREATE', 'CURRENT', - 'CURRENT_TIMESTAMP', 'DATABASE', 'DATE', 'DECIMAL', 'DEFAULT', - 'DESC', 'DESCRIBE', 'DROP', 'ELSE', 'END', 'ENGINE', 'ESCAPE', - 'EXISTS', 'FILE', 'FLOAT', 'FOR', 'FOREIGN KEY', 'FORMAT', 'FULL', - 'FUNCTION', 'GRANT', 'HAVING', 'HOST', 'IDENTIFIED', 'IN', - 'INCREMENT', 'INDEX', 'INT', 'INTEGER', 'INTERVAL', 'INTO', 'IS', - 'KEY', 'LEFT', 'LEVEL', 'LOCK', 'LOGS', 'LONG', 'MASTER', - 'MEDIUMINT', 'MODE', 'MODIFY', 'NOT', 'NULL', 'NUMBER', 'OFFSET', - 'ON', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'OWNER', 'PASSWORD', - 'PORT', 'PRIMARY', 'PRIVILEGES', 'PROCESSLIST', 'PURGE', - 'REFERENCES', 'REGEXP', 'RENAME', 'REPAIR', 'RESET', 'REVOKE', - 'RIGHT', 'ROLLBACK', 'ROW', 'ROWS', 'ROW_FORMAT', 'SAVEPOINT', - 'SESSION', 'SET', 'SHARE', 'SHOW', 'SLAVE', 'SMALLINT', - 'START', 'STOP', 'TABLE', 'THEN', 'TINYINT', 'TO', 'TRANSACTION', - 'TRIGGER', 'TRUNCATE', 'UNION', 'UNIQUE', 'UNSIGNED', 'USE', - 'USER', 'USING', 'VALUES', 'VARCHAR', 'VIEW', 'WHEN', 'WITH' - ] + favorite_keywords = [ + 'SELECT', + 'FROM', + 'WHERE', + 'UPDATE', + 'DELETE FROM', + 'GROUP BY', + 'ORDER BY', + 'JOIN', + 'LEFT JOIN', + 'INSERT INTO', + 'LIKE', + 'LIMIT', + 'WITH', + 'EXPLAIN', + ] + keywords_raw = [ + x.upper() + for x in favorite_keywords + + list(MYSQL_DATATYPES) + + list(MYSQL_KEYWORDS) + + ['ALTER TABLE', 'CHANGE MASTER TO', 'CHARACTER SET', 'FOREIGN KEY'] + ] + keywords_d = dict.fromkeys(keywords_raw) + for x in SPECIAL_COMMANDS: + if x.upper() in keywords_d: + del keywords_d[x.upper()] + keywords = list(keywords_d) tidb_keywords = [ - "SELECT", "FROM", "WHERE", "DELETE FROM", "UPDATE", "GROUP BY", - "JOIN", "INSERT INTO", "LIKE", "LIMIT", "ACCOUNT", "ACTION", "ADD", - "ADDDATE", "ADMIN", "ADVISE", "AFTER", "AGAINST", "AGO", - "ALGORITHM", "ALL", "ALTER", "ALWAYS", "ANALYZE", "AND", "ANY", - "APPROX_COUNT_DISTINCT", "APPROX_PERCENTILE", "AS", "ASC", "ASCII", - "ATTRIBUTES", "AUTO_ID_CACHE", "AUTO_INCREMENT", "AUTO_RANDOM", - "AUTO_RANDOM_BASE", "AVG", "AVG_ROW_LENGTH", "BACKEND", "BACKUP", - "BACKUPS", "BATCH", "BEGIN", "BERNOULLI", "BETWEEN", "BIGINT", - "BINARY", "BINDING", "BINDINGS", "BINDING_CACHE", "BINLOG", "BIT", - "BIT_AND", "BIT_OR", "BIT_XOR", "BLOB", "BLOCK", "BOOL", "BOOLEAN", - "BOTH", "BOUND", "BRIEF", "BTREE", "BUCKETS", "BUILTINS", "BY", - "BYTE", "CACHE", "CALL", "CANCEL", "CAPTURE", "CARDINALITY", - "CASCADE", "CASCADED", "CASE", "CAST", "CAUSAL", "CHAIN", "CHANGE", - "CHAR", "CHARACTER", "CHARSET", "CHECK", "CHECKPOINT", "CHECKSUM", - "CIPHER", "CLEANUP", "CLIENT", "CLIENT_ERRORS_SUMMARY", - "CLUSTERED", "CMSKETCH", "COALESCE", "COLLATE", "COLLATION", - "COLUMN", "COLUMNS", "COLUMN_FORMAT", "COLUMN_STATS_USAGE", - "COMMENT", "COMMIT", "COMMITTED", "COMPACT", "COMPRESSED", - "COMPRESSION", "CONCURRENCY", "CONFIG", "CONNECTION", - "CONSISTENCY", "CONSISTENT", "CONSTRAINT", "CONSTRAINTS", - "CONTEXT", "CONVERT", "COPY", "CORRELATION", "CPU", "CREATE", - "CROSS", "CSV_BACKSLASH_ESCAPE", "CSV_DELIMITER", "CSV_HEADER", - "CSV_NOT_NULL", "CSV_NULL", "CSV_SEPARATOR", - "CSV_TRIM_LAST_SEPARATORS", "CUME_DIST", "CURRENT", "CURRENT_DATE", - "CURRENT_ROLE", "CURRENT_TIME", "CURRENT_TIMESTAMP", - "CURRENT_USER", "CURTIME", "CYCLE", "DATA", "DATABASE", - "DATABASES", "DATE", "DATETIME", "DATE_ADD", "DATE_SUB", "DAY", - "DAY_HOUR", "DAY_MICROSECOND", "DAY_MINUTE", "DAY_SECOND", "DDL", - "DEALLOCATE", "DECIMAL", "DEFAULT", "DEFINER", "DELAYED", - "DELAY_KEY_WRITE", "DENSE_RANK", "DEPENDENCY", "DEPTH", "DESC", - "DESCRIBE", "DIRECTORY", "DISABLE", "DISABLED", "DISCARD", "DISK", - "DISTINCT", "DISTINCTROW", "DIV", "DO", "DOT", "DOUBLE", "DRAINER", - "DROP", "DRY", "DUAL", "DUMP", "DUPLICATE", "DYNAMIC", "ELSE", - "ENABLE", "ENABLED", "ENCLOSED", "ENCRYPTION", "END", "ENFORCED", - "ENGINE", "ENGINES", "ENUM", "ERROR", "ERRORS", "ESCAPE", - "ESCAPED", "EVENT", "EVENTS", "EVOLVE", "EXACT", "EXCEPT", - "EXCHANGE", "EXCLUSIVE", "EXECUTE", "EXISTS", "EXPANSION", - "EXPIRE", "EXPLAIN", "EXPR_PUSHDOWN_BLACKLIST", "EXTENDED", - "EXTRACT", "FALSE", "FAST", "FAULTS", "FETCH", "FIELDS", "FILE", - "FIRST", "FIRST_VALUE", "FIXED", "FLASHBACK", "FLOAT", "FLUSH", - "FOLLOWER", "FOLLOWERS", "FOLLOWER_CONSTRAINTS", "FOLLOWING", - "FOR", "FORCE", "FOREIGN", "FORMAT", "FULL", "FULLTEXT", - "FUNCTION", "GENERAL", "GENERATED", "GET_FORMAT", "GLOBAL", - "GRANT", "GRANTS", "GROUPS", "GROUP_CONCAT", "HASH", "HAVING", - "HELP", "HIGH_PRIORITY", "HISTOGRAM", "HISTOGRAMS_IN_FLIGHT", - "HISTORY", "HOSTS", "HOUR", "HOUR_MICROSECOND", "HOUR_MINUTE", - "HOUR_SECOND", "IDENTIFIED", "IF", "IGNORE", "IMPORT", "IMPORTS", - "IN", "INCREMENT", "INCREMENTAL", "INDEX", "INDEXES", "INFILE", - "INNER", "INPLACE", "INSERT_METHOD", "INSTANCE", - "INSTANT", "INT", "INT1", "INT2", "INT3", "INT4", "INT8", - "INTEGER", "INTERNAL", "INTERSECT", "INTERVAL", "INTO", - "INVISIBLE", "INVOKER", "IO", "IPC", "IS", "ISOLATION", "ISSUER", - "JOB", "JOBS", "JSON", "JSON_ARRAYAGG", "JSON_OBJECTAGG", "KEY", - "KEYS", "KEY_BLOCK_SIZE", "KILL", "LABELS", "LAG", "LANGUAGE", - "LAST", "LASTVAL", "LAST_BACKUP", "LAST_VALUE", "LEAD", "LEADER", - "LEADER_CONSTRAINTS", "LEADING", "LEARNER", "LEARNERS", - "LEARNER_CONSTRAINTS", "LEFT", "LESS", "LEVEL", "LINEAR", "LINES", - "LIST", "LOAD", "LOCAL", "LOCALTIME", "LOCALTIMESTAMP", "LOCATION", - "LOCK", "LOCKED", "LOGS", "LONG", "LONGBLOB", "LONGTEXT", - "LOW_PRIORITY", "MASTER", "MATCH", "MAX", "MAXVALUE", - "MAX_CONNECTIONS_PER_HOUR", "MAX_IDXNUM", "MAX_MINUTES", - "MAX_QUERIES_PER_HOUR", "MAX_ROWS", "MAX_UPDATES_PER_HOUR", - "MAX_USER_CONNECTIONS", "MB", "MEDIUMBLOB", "MEDIUMINT", - "MEDIUMTEXT", "MEMORY", "MERGE", "MICROSECOND", "MIN", "MINUTE", - "MINUTE_MICROSECOND", "MINUTE_SECOND", "MINVALUE", "MIN_ROWS", - "MOD", "MODE", "MODIFY", "MONTH", "NAMES", "NATIONAL", "NATURAL", - "NCHAR", "NEVER", "NEXT", "NEXTVAL", "NEXT_ROW_ID", "NO", - "NOCACHE", "NOCYCLE", "NODEGROUP", "NODE_ID", "NODE_STATE", - "NOMAXVALUE", "NOMINVALUE", "NONCLUSTERED", "NONE", "NORMAL", - "NOT", "NOW", "NOWAIT", "NO_WRITE_TO_BINLOG", "NTH_VALUE", "NTILE", - "NULL", "NULLS", "NUMERIC", "NVARCHAR", "OF", "OFF", "OFFSET", - "ON", "ONLINE", "ONLY", "ON_DUPLICATE", "OPEN", "OPTIMISTIC", - "OPTIMIZE", "OPTION", "OPTIONAL", "OPTIONALLY", - "OPT_RULE_BLACKLIST", "OR", "ORDER", "OUTER", "OUTFILE", "OVER", - "PACK_KEYS", "PAGE", "PARSER", "PARTIAL", "PARTITION", - "PARTITIONING", "PARTITIONS", "PASSWORD", "PERCENT", - "PERCENT_RANK", "PER_DB", "PER_TABLE", "PESSIMISTIC", "PLACEMENT", - "PLAN", "PLAN_CACHE", "PLUGINS", "POLICY", "POSITION", "PRECEDING", - "PRECISION", "PREDICATE", "PREPARE", "PRESERVE", - "PRE_SPLIT_REGIONS", "PRIMARY", "PRIMARY_REGION", "PRIVILEGES", - "PROCEDURE", "PROCESS", "PROCESSLIST", "PROFILE", "PROFILES", - "PROXY", "PUMP", "PURGE", "QUARTER", "QUERIES", "QUERY", "QUICK", - "RANGE", "RANK", "RATE_LIMIT", "READ", "REAL", "REBUILD", "RECENT", - "RECOVER", "RECURSIVE", "REDUNDANT", "REFERENCES", "REGEXP", - "REGION", "REGIONS", "RELEASE", "RELOAD", "REMOVE", "RENAME", - "REORGANIZE", "REPAIR", "REPEAT", "REPEATABLE", "REPLACE", - "REPLAYER", "REPLICA", "REPLICAS", "REPLICATION", "REQUIRE", - "REQUIRED", "RESET", "RESPECT", "RESTART", "RESTORE", "RESTORES", - "RESTRICT", "RESUME", "REVERSE", "REVOKE", "RIGHT", "RLIKE", - "ROLE", "ROLLBACK", "ROUTINE", "ROW", "ROWS", "ROW_COUNT", - "ROW_FORMAT", "ROW_NUMBER", "RTREE", "RUN", "RUNNING", "S3", - "SAMPLERATE", "SAMPLES", "SAN", "SAVEPOINT", "SCHEDULE", "SECOND", - "SECONDARY_ENGINE", "SECONDARY_LOAD", "SECONDARY_UNLOAD", - "SECOND_MICROSECOND", "SECURITY", "SEND_CREDENTIALS_TO_TIKV", - "SEPARATOR", "SEQUENCE", "SERIAL", "SERIALIZABLE", "SESSION", - "SESSION_STATES", "SET", "SETVAL", "SHARD_ROW_ID_BITS", "SHARE", - "SHARED", "SHOW", "SHUTDOWN", "SIGNED", "SIMPLE", "SKIP", - "SKIP_SCHEMA_FILES", "SLAVE", "SLOW", "SMALLINT", "SNAPSHOT", - "SOME", "SOURCE", "SPATIAL", "SPLIT", "SQL", "SQL_BIG_RESULT", - "SQL_BUFFER_RESULT", "SQL_CACHE", "SQL_CALC_FOUND_ROWS", - "SQL_NO_CACHE", "SQL_SMALL_RESULT", "SQL_TSI_DAY", "SQL_TSI_HOUR", - "SQL_TSI_MINUTE", "SQL_TSI_MONTH", "SQL_TSI_QUARTER", - "SQL_TSI_SECOND", "SQL_TSI_WEEK", "SQL_TSI_YEAR", "SSL", - "STALENESS", "START", "STARTING", "STATISTICS", "STATS", - "STATS_AUTO_RECALC", "STATS_BUCKETS", "STATS_COL_CHOICE", - "STATS_COL_LIST", "STATS_EXTENDED", "STATS_HEALTHY", - "STATS_HISTOGRAMS", "STATS_META", "STATS_OPTIONS", - "STATS_PERSISTENT", "STATS_SAMPLE_PAGES", "STATS_SAMPLE_RATE", - "STATS_TOPN", "STATUS", "STD", "STDDEV", "STDDEV_POP", - "STDDEV_SAMP", "STOP", "STORAGE", "STORED", "STRAIGHT_JOIN", - "STRICT", "STRICT_FORMAT", "STRONG", "SUBDATE", "SUBJECT", - "SUBPARTITION", "SUBPARTITIONS", "SUBSTRING", "SUM", "SUPER", - "SWAPS", "SWITCHES", "SYSTEM", "SYSTEM_TIME", "TABLE", "TABLES", - "TABLESAMPLE", "TABLESPACE", "TABLE_CHECKSUM", "TARGET", - "TELEMETRY", "TELEMETRY_ID", "TEMPORARY", "TEMPTABLE", - "TERMINATED", "TEXT", "THAN", "THEN", "TIDB", "TIFLASH", - "TIKV_IMPORTER", "TIME", "TIMESTAMP", "TIMESTAMPADD", - "TIMESTAMPDIFF", "TINYBLOB", "TINYINT", "TINYTEXT", "TLS", "TO", - "TOKUDB_DEFAULT", "TOKUDB_FAST", "TOKUDB_LZMA", "TOKUDB_QUICKLZ", - "TOKUDB_SMALL", "TOKUDB_SNAPPY", "TOKUDB_UNCOMPRESSED", - "TOKUDB_ZLIB", "TOP", "TOPN", "TRACE", "TRADITIONAL", "TRAILING", - "TRANSACTION", "TRIGGER", "TRIGGERS", "TRIM", "TRUE", - "TRUE_CARD_COST", "TRUNCATE", "TYPE", "UNBOUNDED", "UNCOMMITTED", - "UNDEFINED", "UNICODE", "UNION", "UNIQUE", "UNKNOWN", "UNLOCK", - "UNSIGNED", "USAGE", "USE", "USER", "USING", "UTC_DATE", - "UTC_TIME", "UTC_TIMESTAMP", "VALIDATION", "VALUE", "VALUES", - "VARBINARY", "VARCHAR", "VARCHARACTER", "VARIABLES", "VARIANCE", - "VARYING", "VAR_POP", "VAR_SAMP", "VERBOSE", "VIEW", "VIRTUAL", - "VISIBLE", "VOTER", "VOTERS", "VOTER_CONSTRAINTS", "WAIT", - "WARNINGS", "WEEK", "WEIGHT_STRING", "WHEN", "WIDTH", "WINDOW", - "WITH", "WITHOUT", "WRITE", "X509", "XOR", "YEAR", "YEAR_MONTH", - "ZEROFILL" - ] - - functions = ['AVG', 'CONCAT', 'COUNT', 'DISTINCT', 'FIRST', 'FORMAT', - 'FROM_UNIXTIME', 'LAST', 'LCASE', 'LEN', 'MAX', 'MID', - 'MIN', 'NOW', 'ROUND', 'SUM', 'TOP', 'UCASE', - 'UNIX_TIMESTAMP' - ] + "SELECT", + "FROM", + "WHERE", + "DELETE FROM", + "UPDATE", + "GROUP BY", + "JOIN", + "INSERT INTO", + "LIKE", + "LIMIT", + "ACCOUNT", + "ACTION", + "ADD", + "ADDDATE", + "ADMIN", + "ADVISE", + "AFTER", + "AGAINST", + "AGO", + "ALGORITHM", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "ANY", + "APPROX_COUNT_DISTINCT", + "APPROX_PERCENTILE", + "AS", + "ASC", + "ASCII", + "ATTRIBUTES", + "AUTO_ID_CACHE", + "AUTO_INCREMENT", + "AUTO_RANDOM", + "AUTO_RANDOM_BASE", + "AVG", + "AVG_ROW_LENGTH", + "BACKEND", + "BACKUP", + "BACKUPS", + "BATCH", + "BEGIN", + "BERNOULLI", + "BETWEEN", + "BIGINT", + "BINARY", + "BINDING", + "BINDINGS", + "BINDING_CACHE", + "BINLOG", + "BIT", + "BIT_AND", + "BIT_OR", + "BIT_XOR", + "BLOB", + "BLOCK", + "BOOL", + "BOOLEAN", + "BOTH", + "BOUND", + "BRIEF", + "BTREE", + "BUCKETS", + "BUILTINS", + "BY", + "BYTE", + "CACHE", + "CALL", + "CANCEL", + "CAPTURE", + "CARDINALITY", + "CASCADE", + "CASCADED", + "CASE", + "CAST", + "CAUSAL", + "CHAIN", + "CHANGE", + "CHAR", + "CHARACTER", + "CHARSET", + "CHECK", + "CHECKPOINT", + "CHECKSUM", + "CIPHER", + "CLEANUP", + "CLIENT", + "CLIENT_ERRORS_SUMMARY", + "CLUSTERED", + "CMSKETCH", + "COALESCE", + "COLLATE", + "COLLATION", + "COLUMN", + "COLUMNS", + "COLUMN_FORMAT", + "COLUMN_STATS_USAGE", + "COMMENT", + "COMMIT", + "COMMITTED", + "COMPACT", + "COMPRESSED", + "COMPRESSION", + "CONCURRENCY", + "CONFIG", + "CONNECTION", + "CONSISTENCY", + "CONSISTENT", + "CONSTRAINT", + "CONSTRAINTS", + "CONTEXT", + "CONVERT", + "COPY", + "CORRELATION", + "CPU", + "CREATE", + "CROSS", + "CSV_BACKSLASH_ESCAPE", + "CSV_DELIMITER", + "CSV_HEADER", + "CSV_NOT_NULL", + "CSV_NULL", + "CSV_SEPARATOR", + "CSV_TRIM_LAST_SEPARATORS", + "CUME_DIST", + "CURRENT", + "CURRENT_DATE", + "CURRENT_ROLE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURTIME", + "CYCLE", + "DATA", + "DATABASE", + "DATABASES", + "DATE", + "DATETIME", + "DATE_ADD", + "DATE_SUB", + "DAY", + "DAY_HOUR", + "DAY_MICROSECOND", + "DAY_MINUTE", + "DAY_SECOND", + "DDL", + "DEALLOCATE", + "DECIMAL", + "DEFAULT", + "DEFINER", + "DELAYED", + "DELAY_KEY_WRITE", + "DENSE_RANK", + "DEPENDENCY", + "DEPTH", + "DESC", + "DESCRIBE", + "DIRECTORY", + "DISABLE", + "DISABLED", + "DISCARD", + "DISK", + "DISTINCT", + "DISTINCTROW", + "DIV", + "DO", + "DOT", + "DOUBLE", + "DRAINER", + "DROP", + "DRY", + "DUAL", + "DUMP", + "DUPLICATE", + "DYNAMIC", + "ELSE", + "ENABLE", + "ENABLED", + "ENCLOSED", + "ENCRYPTION", + "END", + "ENFORCED", + "ENGINE", + "ENGINES", + "ENUM", + "ERROR", + "ERRORS", + "ESCAPE", + "ESCAPED", + "EVENT", + "EVENTS", + "EVOLVE", + "EXACT", + "EXCEPT", + "EXCHANGE", + "EXCLUSIVE", + "EXECUTE", + "EXISTS", + "EXPANSION", + "EXPIRE", + "EXPLAIN", + "EXPR_PUSHDOWN_BLACKLIST", + "EXTENDED", + "EXTRACT", + "FALSE", + "FAST", + "FAULTS", + "FETCH", + "FIELDS", + "FILE", + "FIRST", + "FIRST_VALUE", + "FIXED", + "FLASHBACK", + "FLOAT", + "FLUSH", + "FOLLOWER", + "FOLLOWERS", + "FOLLOWER_CONSTRAINTS", + "FOLLOWING", + "FOR", + "FORCE", + "FOREIGN", + "FORMAT", + "FULL", + "FULLTEXT", + "FUNCTION", + "GENERAL", + "GENERATED", + "GET_FORMAT", + "GLOBAL", + "GRANT", + "GRANTS", + "GROUPS", + "GROUP_CONCAT", + "HASH", + "HAVING", + "HELP", + "HIGH_PRIORITY", + "HISTOGRAM", + "HISTOGRAMS_IN_FLIGHT", + "HISTORY", + "HOSTS", + "HOUR", + "HOUR_MICROSECOND", + "HOUR_MINUTE", + "HOUR_SECOND", + "IDENTIFIED", + "IF", + "IGNORE", + "IMPORT", + "IMPORTS", + "IN", + "INCREMENT", + "INCREMENTAL", + "INDEX", + "INDEXES", + "INFILE", + "INNER", + "INPLACE", + "INSERT_METHOD", + "INSTANCE", + "INSTANT", + "INT", + "INT1", + "INT2", + "INT3", + "INT4", + "INT8", + "INTEGER", + "INTERNAL", + "INTERSECT", + "INTERVAL", + "INTO", + "INVISIBLE", + "INVOKER", + "IO", + "IPC", + "IS", + "ISOLATION", + "ISSUER", + "JOB", + "JOBS", + "JSON", + "JSON_ARRAYAGG", + "JSON_OBJECTAGG", + "KEY", + "KEYS", + "KEY_BLOCK_SIZE", + "KILL", + "LABELS", + "LAG", + "LANGUAGE", + "LAST", + "LASTVAL", + "LAST_BACKUP", + "LAST_VALUE", + "LEAD", + "LEADER", + "LEADER_CONSTRAINTS", + "LEADING", + "LEARNER", + "LEARNERS", + "LEARNER_CONSTRAINTS", + "LEFT", + "LESS", + "LEVEL", + "LINEAR", + "LINES", + "LIST", + "LOAD", + "LOCAL", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOCATION", + "LOCK", + "LOCKED", + "LOGS", + "LONG", + "LONGBLOB", + "LONGTEXT", + "LOW_PRIORITY", + "MASTER", + "MATCH", + "MAX", + "MAXVALUE", + "MAX_CONNECTIONS_PER_HOUR", + "MAX_IDXNUM", + "MAX_MINUTES", + "MAX_QUERIES_PER_HOUR", + "MAX_ROWS", + "MAX_UPDATES_PER_HOUR", + "MAX_USER_CONNECTIONS", + "MB", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "MEMORY", + "MERGE", + "MICROSECOND", + "MIN", + "MINUTE", + "MINUTE_MICROSECOND", + "MINUTE_SECOND", + "MINVALUE", + "MIN_ROWS", + "MOD", + "MODE", + "MODIFY", + "MONTH", + "NAMES", + "NATIONAL", + "NATURAL", + "NCHAR", + "NEVER", + "NEXT", + "NEXTVAL", + "NEXT_ROW_ID", + "NO", + "NOCACHE", + "NOCYCLE", + "NODEGROUP", + "NODE_ID", + "NODE_STATE", + "NOMAXVALUE", + "NOMINVALUE", + "NONCLUSTERED", + "NONE", + "NORMAL", + "NOT", + "NOW", + "NOWAIT", + "NO_WRITE_TO_BINLOG", + "NTH_VALUE", + "NTILE", + "NULL", + "NULLS", + "NUMERIC", + "NVARCHAR", + "OF", + "OFF", + "OFFSET", + "ON", + "ONLINE", + "ONLY", + "ON_DUPLICATE", + "OPEN", + "OPTIMISTIC", + "OPTIMIZE", + "OPTION", + "OPTIONAL", + "OPTIONALLY", + "OPT_RULE_BLACKLIST", + "OR", + "ORDER", + "OUTER", + "OUTFILE", + "OVER", + "PACK_KEYS", + "PAGE", + "PARSER", + "PARTIAL", + "PARTITION", + "PARTITIONING", + "PARTITIONS", + "PASSWORD", + "PERCENT", + "PERCENT_RANK", + "PER_DB", + "PER_TABLE", + "PESSIMISTIC", + "PLACEMENT", + "PLAN", + "PLAN_CACHE", + "PLUGINS", + "POLICY", + "POSITION", + "PRECEDING", + "PRECISION", + "PREDICATE", + "PREPARE", + "PRESERVE", + "PRE_SPLIT_REGIONS", + "PRIMARY", + "PRIMARY_REGION", + "PRIVILEGES", + "PROCEDURE", + "PROCESS", + "PROCESSLIST", + "PROFILE", + "PROFILES", + "PROXY", + "PUMP", + "PURGE", + "QUARTER", + "QUERIES", + "QUERY", + "QUICK", + "RANGE", + "RANK", + "RATE_LIMIT", + "READ", + "REAL", + "REBUILD", + "RECENT", + "RECOVER", + "RECURSIVE", + "REDUNDANT", + "REFERENCES", + "REGEXP", + "REGION", + "REGIONS", + "RELEASE", + "RELOAD", + "REMOVE", + "RENAME", + "REORGANIZE", + "REPAIR", + "REPEAT", + "REPEATABLE", + "REPLACE", + "REPLAYER", + "REPLICA", + "REPLICAS", + "REPLICATION", + "REQUIRE", + "REQUIRED", + "RESET", + "RESPECT", + "RESTART", + "RESTORE", + "RESTORES", + "RESTRICT", + "RESUME", + "REVERSE", + "REVOKE", + "RIGHT", + "RLIKE", + "ROLE", + "ROLLBACK", + "ROUTINE", + "ROW", + "ROWS", + "ROW_COUNT", + "ROW_FORMAT", + "ROW_NUMBER", + "RTREE", + "RUN", + "RUNNING", + "S3", + "SAMPLERATE", + "SAMPLES", + "SAN", + "SAVEPOINT", + "SCHEDULE", + "SECOND", + "SECONDARY_ENGINE", + "SECONDARY_LOAD", + "SECONDARY_UNLOAD", + "SECOND_MICROSECOND", + "SECURITY", + "SEND_CREDENTIALS_TO_TIKV", + "SEPARATOR", + "SEQUENCE", + "SERIAL", + "SERIALIZABLE", + "SESSION", + "SESSION_STATES", + "SET", + "SETVAL", + "SHARD_ROW_ID_BITS", + "SHARE", + "SHARED", + "SHOW", + "SHUTDOWN", + "SIGNED", + "SIMPLE", + "SKIP", + "SKIP_SCHEMA_FILES", + "SLAVE", + "SLOW", + "SMALLINT", + "SNAPSHOT", + "SOME", + "SOURCE", + "SPATIAL", + "SPLIT", + "SQL", + "SQL_BIG_RESULT", + "SQL_BUFFER_RESULT", + "SQL_CACHE", + "SQL_CALC_FOUND_ROWS", + "SQL_NO_CACHE", + "SQL_SMALL_RESULT", + "SQL_TSI_DAY", + "SQL_TSI_HOUR", + "SQL_TSI_MINUTE", + "SQL_TSI_MONTH", + "SQL_TSI_QUARTER", + "SQL_TSI_SECOND", + "SQL_TSI_WEEK", + "SQL_TSI_YEAR", + "SSL", + "STALENESS", + "START", + "STARTING", + "STATISTICS", + "STATS", + "STATS_AUTO_RECALC", + "STATS_BUCKETS", + "STATS_COL_CHOICE", + "STATS_COL_LIST", + "STATS_EXTENDED", + "STATS_HEALTHY", + "STATS_HISTOGRAMS", + "STATS_META", + "STATS_OPTIONS", + "STATS_PERSISTENT", + "STATS_SAMPLE_PAGES", + "STATS_SAMPLE_RATE", + "STATS_TOPN", + "STATUS", + "STD", + "STDDEV", + "STDDEV_POP", + "STDDEV_SAMP", + "STOP", + "STORAGE", + "STORED", + "STRAIGHT_JOIN", + "STRICT", + "STRICT_FORMAT", + "STRONG", + "SUBDATE", + "SUBJECT", + "SUBPARTITION", + "SUBPARTITIONS", + "SUBSTRING", + "SUM", + "SUPER", + "SWAPS", + "SWITCHES", + "SYSTEM", + "SYSTEM_TIME", + "TABLE", + "TABLES", + "TABLESAMPLE", + "TABLESPACE", + "TABLE_CHECKSUM", + "TARGET", + "TELEMETRY", + "TELEMETRY_ID", + "TEMPORARY", + "TEMPTABLE", + "TERMINATED", + "TEXT", + "THAN", + "THEN", + "TIDB", + "TIFLASH", + "TIKV_IMPORTER", + "TIME", + "TIMESTAMP", + "TIMESTAMPADD", + "TIMESTAMPDIFF", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "TLS", + "TO", + "TOKUDB_DEFAULT", + "TOKUDB_FAST", + "TOKUDB_LZMA", + "TOKUDB_QUICKLZ", + "TOKUDB_SMALL", + "TOKUDB_SNAPPY", + "TOKUDB_UNCOMPRESSED", + "TOKUDB_ZLIB", + "TOP", + "TOPN", + "TRACE", + "TRADITIONAL", + "TRAILING", + "TRANSACTION", + "TRIGGER", + "TRIGGERS", + "TRIM", + "TRUE", + "TRUE_CARD_COST", + "TRUNCATE", + "TYPE", + "UNBOUNDED", + "UNCOMMITTED", + "UNDEFINED", + "UNICODE", + "UNION", + "UNIQUE", + "UNKNOWN", + "UNLOCK", + "UNSIGNED", + "USAGE", + "USE", + "USER", + "USING", + "UTC_DATE", + "UTC_TIME", + "UTC_TIMESTAMP", + "VALIDATION", + "VALUE", + "VALUES", + "VARBINARY", + "VARCHAR", + "VARCHARACTER", + "VARIABLES", + "VARIANCE", + "VARYING", + "VAR_POP", + "VAR_SAMP", + "VERBOSE", + "VIEW", + "VIRTUAL", + "VISIBLE", + "VOTER", + "VOTERS", + "VOTER_CONSTRAINTS", + "WAIT", + "WARNINGS", + "WEEK", + "WEIGHT_STRING", + "WHEN", + "WIDTH", + "WINDOW", + "WITH", + "WITHOUT", + "WRITE", + "X509", + "XOR", + "YEAR", + "YEAR_MONTH", + "ZEROFILL", + ] + + # misclassified as keywords + # do they need to also be subtracted from keywords? + pygments_misclassified_functions = [ + 'ASCII', + 'AVG', + 'CHARSET', + 'COALESCE', + 'COLLATION', + 'CONVERT', + 'CUME_DIST', + 'CURRENT_DATE', + 'CURRENT_TIME', + 'CURRENT_TIMESTAMP', + 'CURRENT_USER', + 'DATABASE', + 'DAY', + 'DEFAULT', + 'DENSE_RANK', + 'EXISTS', + 'FIRST_VALUE', + 'FORMAT', + 'GEOMCOLLECTION', + 'GET_FORMAT', + 'GROUPING', + 'HOUR', + 'IF', + 'INSERT', + 'INTERVAL', + 'JSON_TABLE', + 'JSON_VALUE', + 'LAG', + 'LAST_VALUE', + 'LEAD', + 'LEFT', + 'LOCALTIME', + 'LOCALTIMESTAMP', + 'MATCH', + 'MICROSECOND', + 'MINUTE', + 'MOD', + 'MONTH', + 'NTH_VALUE', + 'NTILE', + 'PERCENT_RANK', + 'QUARTER', + 'RANK', + 'REPEAT', + 'REPLACE', + 'REVERSE', + 'RIGHT', + 'ROW_COUNT', + 'ROW_NUMBER', + 'SCHEMA', + 'SECOND', + 'TIMESTAMPADD', + 'TIMESTAMPDIFF', + 'TRUNCATE', + 'USER', + 'UTC_DATE', + 'UTC_TIME', + 'UTC_TIMESTAMP', + 'VALUES', + 'WEEK', + 'WEIGHT_STRING', + ] + + # should case be respected for functions styled as CamelCase? + pygments_missing_functions = [ + 'BINARY', # deprecated function, but available everywhere + 'CHAR', + 'DATE', + 'DISTANCE', + 'ETAG', + 'GeometryCollection', + 'JSON_DUALITY_OBJECT', + 'LineString', + 'MultiLineString', + 'MultiPoint', + 'MultiPolygon', + 'Point', + 'Polygon', + 'STRING_TO_VECTOR', + 'TIME', + 'TIMESTAMP', + 'VECTOR_DIM', + 'VECTOR_TO_STRING', + 'YEAR', + ] + + # so far an incomplete list + # these should be spun out and completed independently from functions in the value position + pygments_value_position_nonfunction_keywords = [ + 'BETWEEN', + 'CASE', + 'DISTINCT', + 'FALSE', + 'NOT', + 'NULL', + 'TRUE', + ] + + # should https://dev.mysql.com/doc/refman/9.6/en/loadable-function-reference.html also be added? + pygments_functions_supplemented = sorted( + [x.upper() for x in MYSQL_FUNCTIONS] + + [x.upper() for x in pygments_misclassified_functions] + + [x.upper() for x in pygments_missing_functions] + + [x.upper() for x in pygments_value_position_nonfunction_keywords] + ) + + favorite_functions = [ + 'COUNT', + 'CONVERT', + 'BINARY', + 'CAST', + 'COALESCE', + 'MAX', + 'MIN', + 'SUM', + 'AVG', + 'JSON_EXTRACT', + 'JSON_VALUE', + 'JSON_REMOVE', + 'JSON_SET', + 'CONCAT', + 'GROUP_CONCAT', + 'CHAR_LENGTH', + 'ROUND', + 'FLOOR', + 'CEIL', + 'IF', + 'IFNULL', + 'SUBSTR', + 'SUBSTRING_INDEX', + 'REPLACE', + 'RIGHT', + 'LEFT', + 'UNIX_TIMESTAMP', + 'FROM_UNIXTIME', + 'RAND', + 'DATEDIFF', + 'DATE_SUB', + ] + functions_raw = favorite_functions + pygments_functions_supplemented + functions = list(dict.fromkeys(functions_raw)) # https://docs.pingcap.com/tidb/dev/tidb-functions tidb_functions = [ - 'TIDB_BOUNDED_STALENESS', 'TIDB_DECODE_KEY', 'TIDB_DECODE_PLAN', - 'TIDB_IS_DDL_OWNER', 'TIDB_PARSE_TSO', 'TIDB_VERSION', - 'TIDB_DECODE_SQL_DIGESTS', 'VITESS_HASH', 'TIDB_SHARD' - ] - - - show_items = [] - - change_items = ['MASTER_BIND', 'MASTER_HOST', 'MASTER_USER', - 'MASTER_PASSWORD', 'MASTER_PORT', 'MASTER_CONNECT_RETRY', - 'MASTER_HEARTBEAT_PERIOD', 'MASTER_LOG_FILE', - 'MASTER_LOG_POS', 'RELAY_LOG_FILE', 'RELAY_LOG_POS', - 'MASTER_SSL', 'MASTER_SSL_CA', 'MASTER_SSL_CAPATH', - 'MASTER_SSL_CERT', 'MASTER_SSL_KEY', 'MASTER_SSL_CIPHER', - 'MASTER_SSL_VERIFY_SERVER_CERT', 'IGNORE_SERVER_IDS'] - - users = [] - - def __init__(self, smart_completion=True, supported_formats=(), keyword_casing='auto'): + "TIDB_BOUNDED_STALENESS", + "TIDB_DECODE_KEY", + "TIDB_DECODE_PLAN", + "TIDB_IS_DDL_OWNER", + "TIDB_PARSE_TSO", + "TIDB_VERSION", + "TIDB_DECODE_SQL_DIGESTS", + "VITESS_HASH", + "TIDB_SHARD", + ] + + show_items: list[Completion] = [] + + change_items = [ + "MASTER_BIND", + "MASTER_HOST", + "MASTER_USER", + "MASTER_PASSWORD", + "MASTER_PORT", + "MASTER_CONNECT_RETRY", + "MASTER_HEARTBEAT_PERIOD", + "MASTER_LOG_FILE", + "MASTER_LOG_POS", + "RELAY_LOG_FILE", + "RELAY_LOG_POS", + "MASTER_SSL", + "MASTER_SSL_CA", + "MASTER_SSL_CAPATH", + "MASTER_SSL_CERT", + "MASTER_SSL_KEY", + "MASTER_SSL_CIPHER", + "MASTER_SSL_VERIFY_SERVER_CERT", + "IGNORE_SERVER_IDS", + ] + + users: list[str] = [] + + character_sets: list[str] = [] + + collations: list[str] = [] + + def __init__( + self, + smart_completion: bool = True, + supported_formats: tuple = (), + keyword_casing: str = "auto", + ) -> None: super(self.__class__, self).__init__() self.smart_completion = smart_completion self.reserved_words = set() for x in self.keywords: self.reserved_words.update(x.split()) - self.name_pattern = compile(r"^[_a-z][_a-z0-9\$]*$") + self.name_pattern = re.compile(r"^[_a-zA-Z][_a-zA-Z0-9\$]*$") - self.special_commands = [] + self.special_commands: list[str] = [] self.table_formats = supported_formats - if keyword_casing not in ('upper', 'lower', 'auto'): - keyword_casing = 'auto' + if keyword_casing not in ("upper", "lower", "auto"): + keyword_casing = "auto" self.keyword_casing = keyword_casing self.reset_completions() - def escape_name(self, name): - if name and ((not self.name_pattern.match(name)) - or (name.upper() in self.reserved_words) - or (name.upper() in self.functions)): - name = '`%s`' % name + def escape_name(self, name: str) -> str: + if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): + name = f'`{name}`' return name - def unescape_name(self, name): - """Unquote a string.""" - if name and name[0] == '"' and name[-1] == '"': - name = name[1:-1] - - return name - - def escaped_names(self, names): + def escaped_names(self, names: Collection[str]) -> list[str]: return [self.escape_name(name) for name in names] - def extend_special_commands(self, special_commands): + def extend_special_commands(self, special_commands: list[str]) -> None: # Special commands are not part of all_completions since they can only # be at the beginning of a line. self.special_commands.extend(special_commands) - def extend_database_names(self, databases): - self.databases.extend(databases) + def extend_database_names(self, databases: list[str]) -> None: + self.databases.extend([self.escape_name(db) for db in databases]) - def extend_keywords(self, keywords, replace=False): + def extend_keywords(self, keywords: list[str], replace: bool = False) -> None: if replace: self.keywords = keywords else: self.keywords.extend(keywords) self.all_completions.update(keywords) - def extend_show_items(self, show_items): + def extend_show_items(self, show_items: Iterable[tuple]) -> None: for show_item in show_items: self.show_items.extend(show_item) self.all_completions.update(show_item) - def extend_change_items(self, change_items): + def extend_change_items(self, change_items: Iterable[tuple]) -> None: for change_item in change_items: self.change_items.extend(change_item) self.all_completions.update(change_item) - def extend_users(self, users): + def extend_users(self, users: Iterable[tuple]) -> None: for user in users: self.users.extend(user) self.all_completions.update(user) - def extend_schemata(self, schema): + def extend_schemata(self, schema: str | None) -> None: if schema is None: return - metadata = self.dbmetadata['tables'] + metadata = self.dbmetadata["tables"] metadata[schema] = {} # dbmetadata.values() are the 'tables' and 'functions' dicts @@ -272,58 +1002,107 @@ def extend_schemata(self, schema): metadata[schema] = {} self.all_completions.update(schema) - def extend_relations(self, data, kind): + def extend_relations(self, data: list[tuple[str, str]], kind: Literal['tables', 'views']) -> None: """Extend metadata for tables or views :param data: list of (rel_name, ) tuples :param kind: either 'tables' or 'views' :return: """ - # 'data' is a generator object. It can throw an exception while being - # consumed. This could happen if the user has launched the app without - # specifying a database name. This exception must be handled to prevent - # crashing. - try: - data = [self.escaped_names(d) for d in data] - except Exception: - data = [] + data_ll = [self.escaped_names(d) for d in data] # dbmetadata['tables'][$schema_name][$table_name] should be a list of # column names. Default to an asterisk metadata = self.dbmetadata[kind] - for relname in data: + for relname in data_ll: try: - metadata[self.dbname][relname[0]] = ['*'] + metadata[self.dbname][relname[0]] = ["*"] except KeyError: - _logger.error('%r %r listed in unrecognized schema %r', - kind, relname[0], self.dbname) + _logger.error("%r %r listed in unrecognized schema %r", kind, relname[0], self.dbname) self.all_completions.add(relname[0]) - def extend_columns(self, column_data, kind): + def extend_columns(self, column_data: list[tuple[str, str]], kind: Literal['tables', 'views']) -> None: """Extend column metadata :param column_data: list of (rel_name, column_name) tuples :param kind: either 'tables' or 'views' :return: """ - # 'column_data' is a generator object. It can throw an exception while - # being consumed. This could happen if the user has launched the app - # without specifying a database name. This exception must be handled to - # prevent crashing. - try: - column_data = [self.escaped_names(d) for d in column_data] - except Exception: - column_data = [] + column_data_ll = [self.escaped_names(d) for d in column_data] metadata = self.dbmetadata[kind] - for relname, column in column_data: + for relname, column in column_data_ll: + if relname not in metadata[self.dbname]: + _logger.error("relname '%s' was not found in db '%s'", relname, self.dbname) + # this could happen back when the completer populated via two calls: + # SHOW TABLES then SELECT table_name, column_name from information_schema.columns + # it's a slight race, but much more likely on Vitess picking random shards for each. + # see discussion in https://github.com/dbcli/mycli/pull/1182 (tl;dr - let's keep it) + continue metadata[self.dbname][relname].append(column) self.all_completions.add(column) - def extend_functions(self, func_data, builtin=False): + def extend_enum_values(self, enum_data: Iterable[tuple[str, str, list[str]]]) -> None: + metadata = self.dbmetadata["enum_values"] + if self.dbname not in metadata: + metadata[self.dbname] = {} + + for relname, column, values in enum_data: + relname_escaped = self.escape_name(relname) + column_escaped = self.escape_name(column) + table_meta = metadata[self.dbname].setdefault(relname_escaped, {}) + table_meta[column_escaped] = values + + def extend_foreign_keys(self, fk_data: Iterable[tuple[str, str, str, str]]) -> None: + """Extend FK metadata. + + :param fk_data: iterable of (table_name, column_name, referenced_table_name, referenced_column_name) + """ + metadata = self.dbmetadata["foreign_keys"] + schema_meta = metadata.setdefault(self.dbname, {}) + schema_meta.setdefault("tables", {}) + schema_meta.setdefault("relations", []) + for table, col, ref_table, ref_col in fk_data: + table = self.escape_name(table) + col = self.escape_name(col) + ref_table = self.escape_name(ref_table) + ref_col = self.escape_name(ref_col) + schema_meta["tables"].setdefault(table, set()).add(ref_table) + schema_meta["tables"].setdefault(ref_table, set()).add(table) + schema_meta["relations"].append((table, col, ref_table, ref_col)) + + def _fk_join_conditions(self, tables: list[tuple[str | None, str, str]]) -> list[str]: + """Return FK-based join condition strings for the tables currently in the query. + + For each FK relation where both the FK table and the referenced table appear in + *tables*, yields a string like ``alias1.col = alias2.ref_col`` (using the alias + when one exists, otherwise the table name). + """ + schema_meta = self.dbmetadata["foreign_keys"].get(self.dbname, {}) + relations = schema_meta.get("relations", []) + + # Map escaped table name -> alias (or table name when no alias). + # Skip tables from a different schema; we only have FK metadata for the current db. + alias_map: dict[str, str] = {} + for tbl_schema, tbl, alias in tables: + if tbl_schema and tbl_schema != self.dbname: + continue + escaped = self.escape_name(tbl) + alias_map[escaped] = alias or tbl + + conditions: list[str] = [] + for fk_table, fk_col, ref_table, ref_col in relations: + lhs = alias_map.get(fk_table) + rhs = alias_map.get(ref_table) + if lhs and rhs: + conditions.append(f"{lhs}.{fk_col} = {rhs}.{ref_col}") + return conditions + + def extend_functions(self, func_data: list[str] | Generator[tuple[str, str]], builtin: bool = False) -> None: # if 'builtin' is set this is extending the list of builtin functions if builtin: - self.functions.extend(func_data) + if isinstance(func_data, list): + self.functions.extend(func_data) return # 'func_data' is a generator object. It can throw an exception while @@ -331,31 +1110,271 @@ def extend_functions(self, func_data, builtin=False): # without specifying a database name. This exception must be handled to # prevent crashing. try: - func_data = [self.escaped_names(d) for d in func_data] + func_data_ll = [self.escaped_names(d) for d in func_data] except Exception: - func_data = [] + func_data_ll = [] # dbmetadata['functions'][$schema_name][$function_name] should return # function metadata. - metadata = self.dbmetadata['functions'] + metadata = self.dbmetadata["functions"] - for func in func_data: + for func in func_data_ll: metadata[self.dbname][func[0]] = None self.all_completions.add(func[0]) - def set_dbname(self, dbname): - self.dbname = dbname + def extend_procedures(self, procedure_data: Generator[tuple]) -> None: + metadata = self.dbmetadata["procedures"] + if self.dbname not in metadata: + metadata[self.dbname] = {} + + for elt in procedure_data: + # not sure why this happens on MariaDB in some cases + # see https://github.com/dbcli/mycli/issues/1531 + if not elt: + continue + if not elt[0]: + continue + metadata[self.dbname][elt[0]] = None - def reset_completions(self): - self.databases = [] - self.users = [] - self.show_items = [] - self.dbname = '' - self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {}} + def extend_character_sets(self, character_set_data: Generator[tuple]) -> None: + for elt in character_set_data: + if not elt: + continue + if not elt[0]: + continue + self.character_sets.append(elt[0]) + self.all_completions.update(elt[0]) + + def extend_collations(self, collation_data: Generator[tuple]) -> None: + for elt in collation_data: + if not elt: + continue + if not elt[0]: + continue + self.collations.append(elt[0]) + self.all_completions.update(elt[0]) + + def set_dbname(self, dbname: str | None) -> None: + self.dbname = dbname or '' + + def load_schema_metadata( + self, + schema: str, + table_columns: dict[str, list[str]], + foreign_keys: dict[str, Any], + enum_values: dict[str, dict[str, list[str]]], + functions: dict[str, None], + procedures: dict[str, None], + ) -> None: + """Atomically replace the completion metadata for *schema*. + + Each argument is pre-built by the caller in the same shape that + ``dbmetadata[kind][schema]`` uses internally. Replacing the + per-schema dicts by assignment (rather than appending to the live + structures) keeps concurrent readers of ``get_completions`` safe. + """ + if not schema: + return + self.dbmetadata["tables"][schema] = table_columns + self.dbmetadata["views"].setdefault(schema, {}) + self.dbmetadata["functions"][schema] = functions + self.dbmetadata["procedures"][schema] = procedures + self.dbmetadata["enum_values"][schema] = enum_values + self.dbmetadata["foreign_keys"][schema] = foreign_keys + self._register_schema_completions(schema, table_columns, functions) + + def copy_other_schemas_from(self, source: "SQLCompleter", exclude: str | None) -> None: + """Copy per-schema metadata from *source*, skipping *exclude*. + + After a completion refresh swaps in a fresh completer that was + populated only with the current schema's data, this restores any + previously-loaded metadata for other schemas so the user can keep + using qualified completions (``OtherSchema.table``) without a + re-fetch. + """ + kinds = ("tables", "views", "functions", "procedures", "enum_values", "foreign_keys") + for kind in kinds: + src_map = source.dbmetadata.get(kind, {}) + dest_map = self.dbmetadata.setdefault(kind, {}) + for schema_name, data in src_map.items(): + if not schema_name or schema_name == exclude: + continue + if schema_name in dest_map: + continue + dest_map[schema_name] = data + for schema_name, table_columns in self.dbmetadata["tables"].items(): + if schema_name == exclude: + continue + functions = self.dbmetadata.get("functions", {}).get(schema_name, {}) + self._register_schema_completions(schema_name, table_columns, functions) + + def _register_schema_completions( + self, + schema: str, + table_columns: dict[str, list[str]], + functions: dict[str, None] | dict[str, Any], + ) -> None: + self.all_completions.add(schema) + for table, cols in table_columns.items(): + self.all_completions.add(table) + for col in cols: + if col != "*": + self.all_completions.add(col) + for func_name in functions: + self.all_completions.add(func_name) + + def reset_completions(self) -> None: + self.databases: list[str] = [] + self.users: list[str] = [] + self.character_sets: list[str] = [] + self.collations: list[str] = [] + self.show_items: list[Completion] = [] + self.dbname = "" + self.dbmetadata: dict[str, Any] = { + "tables": {}, + "views": {}, + "functions": {}, + "procedures": {}, + "enum_values": {}, + "foreign_keys": {}, + } self.all_completions = set(self.keywords + self.functions) - @staticmethod - def find_matches(text, collection, start_only=False, fuzzy=True, casing=None): + def maybe_quote_identifier(self, item: str) -> str: + if item.startswith('`'): + return item + if item == '*': + return item + return '`' + item + '`' + + def quote_collection_if_needed( + self, + text: str, + collection: Collection[Any], + text_before_cursor: str, + ) -> Collection[Any]: + # checking text.startswith() first is an optimization; is_inside_quotes() covers more cases + if text.startswith('`') or is_inside_quotes(text_before_cursor, len(text_before_cursor)) == 'backtick': + return [self.maybe_quote_identifier(x) if isinstance(x, str) else x for x in collection] + return collection + + def word_parts_match( + self, + text_parts: list[str], + item_parts: list[str], + ) -> bool: + occurrences = 0 + for text_part in text_parts: + for item_part in item_parts: + if item_part.startswith(text_part): + occurrences += 1 + break + return occurrences >= len(text_parts) + + def find_fuzzy_match( + self, + item: str, + pattern: re.Pattern[str], + under_words_text: list[str], + case_words_text: list[str], + ) -> int | None: + if pattern.search(item.lower()): + return Fuzziness.REGEX + + under_words_item = [x for x in item.lower().split('_') if x] + if self.word_parts_match(under_words_text, under_words_item): + return Fuzziness.UNDER_WORDS + + case_words_item = re.split(_CASE_CHANGE_PAT, item) + if self.word_parts_match(case_words_text, case_words_item): + return Fuzziness.CAMEL_CASE + + return None + + def find_fuzzy_matches( + self, + last: str, + text: str, + collection: Collection[Any], + ) -> list[tuple[str, int]]: + completions: list[tuple[str, int]] = [] + regex = '.{0,3}?'.join(map(re.escape, text)) + pattern = re.compile(f'({regex})') + under_words_text = [x for x in text.split('_') if x] + case_words_text = re.split(_CASE_CHANGE_PAT, last) + + for item in collection: + fuzziness = self.find_fuzzy_match(item, pattern, under_words_text, case_words_text) + if fuzziness is not None: + completions.append((item, fuzziness)) + + if len(text) >= 4: + rapidfuzz_matches = rapidfuzz.process.extract( + text, + collection, + scorer=rapidfuzz.fuzz.WRatio, + # todo: maybe make our own processor which only does case-folding + # because underscores are valuable info + processor=rapidfuzz.utils.default_process, + limit=20, + score_cutoff=75, + ) + existing = {c[0] for c in completions} + for item, _score, _type in rapidfuzz_matches: + if len(item) < len(text) / 1.5 or item in existing: + continue + completions.append((item, Fuzziness.RAPIDFUZZ)) + + return completions + + def find_perfect_matches( + self, + text: str, + collection: Collection[Any], + start_only: bool, + ) -> list[tuple[str, int]]: + completions: list[tuple[str, int]] = [] + match_end_limit = len(text) if start_only else None + for item in collection: + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + completions.append((item, Fuzziness.PERFECT)) + return completions + + def resolve_casing( + self, + casing: str | None, + last: str, + ) -> str | None: + if casing != 'auto': + return casing + return 'lower' if last and (last[0].islower() or last[-1].islower()) else 'upper' + + def apply_casing( + self, + completions: list[tuple[str, int]], + casing: str | None, + ) -> Generator[tuple[str, int], None, None]: + if casing is None: + return (completion for completion in completions) + + def apply_case(tup: tuple[str, int]) -> tuple[str, int]: + kw, fuzziness = tup + if casing == 'upper': + return (kw.upper(), fuzziness) + return (kw.lower(), fuzziness) + + return (apply_case(completion) for completion in completions) + + def find_matches( + self, + orig_text: str, + collection: Collection, + start_only: bool = False, + fuzzy: bool = True, + casing: str | None = None, + text_before_cursor: str = '', + ) -> Generator[tuple[str, int], None, None]: """Find completion matches for the given text. Given the user's input text and a collection of available @@ -369,174 +1388,362 @@ def find_matches(text, collection, start_only=False, fuzzy=True, casing=None): yields prompt_toolkit Completion instances for any matches found in the collection of available completions. """ - last = last_word(text, include='most_punctuations') + last = last_word(orig_text, include='most_punctuations') text = last.lower() - - completions = [] + quoted_collection = self.quote_collection_if_needed(text, collection, text_before_cursor) if fuzzy: - regex = '.*?'.join(map(escape, text)) - pat = compile('(%s)' % regex) - for item in collection: - r = pat.search(item.lower()) - if r: - completions.append((len(r.group()), r.start(), item)) + completions = self.find_fuzzy_matches(last, text, quoted_collection) else: - match_end_limit = len(text) if start_only else None - for item in collection: - match_point = item.lower().find(text, 0, match_end_limit) - if match_point >= 0: - completions.append((len(text), match_point, item)) - - if casing == 'auto': - casing = 'lower' if last and last[-1].islower() else 'upper' + completions = self.find_perfect_matches(text, quoted_collection, start_only) - def apply_case(kw): - if casing == 'upper': - return kw.upper() - return kw.lower() + casing = self.resolve_casing(casing, last) + return self.apply_casing(completions, casing) - return (Completion(z if casing is None else apply_case(z), -len(text)) - for x, y, z in completions) - - def get_completions(self, document, complete_event, smart_completion=None): + def get_completions( + self, + document: Document, + complete_event: CompleteEvent | None, + smart_completion: bool | None = None, + ) -> Iterable[Completion]: word_before_cursor = document.get_word_before_cursor(WORD=True) + last_for_len = last_word(word_before_cursor, include="most_punctuations") + text_for_len = last_for_len.lower() + last_for_len_paths = last_word(word_before_cursor, include='alphanum_underscore') + if smart_completion is None: smart_completion = self.smart_completion # If smart_completion is off then match any word that starts with # 'word_before_cursor'. if not smart_completion: - return self.find_matches(word_before_cursor, self.all_completions, - start_only=True, fuzzy=False) - - completions = [] + matches = self.find_matches( + word_before_cursor, + self.all_completions, + start_only=True, + fuzzy=False, + text_before_cursor=document.text_before_cursor, + ) + return (Completion(x[0], -len(text_for_len)) for x in matches) + + completions: list[tuple[str, int, int]] = [] suggestions = suggest_type(document.text, document.text_before_cursor) + rigid_sort = False + length_based_on_path = False + rank = 0 for suggestion in suggestions: + _logger.debug("Suggestion type: %r", suggestion["type"]) + rank += 1 - _logger.debug('Suggestion type: %r', suggestion['type']) - - if suggestion['type'] == 'column': - tables = suggestion['tables'] + if suggestion["type"] == "column": + tables = suggestion["tables"] _logger.debug("Completion column scope: %r", tables) scoped_cols = self.populate_scoped_cols(tables) - if suggestion.get('drop_unique'): + if suggestion.get("drop_unique"): # drop_unique is used for 'tb11 JOIN tbl2 USING (...' # which should suggest only columns that appear in more than # one table - scoped_cols = [ - col for (col, count) in Counter(scoped_cols).items() - if count > 1 and col != '*' - ] - - cols = self.find_matches(word_before_cursor, scoped_cols) - completions.extend(cols) - - elif suggestion['type'] == 'function': + scoped_cols = [col for (col, count) in Counter(scoped_cols).items() if count > 1 and col != "*"] + elif not tables: + # if tables was empty, this is a naked SELECT and we are + # showing all columns. So make them unique and sort them. + scoped_cols = sorted(set(scoped_cols), key=lambda s: s.strip('`')) + + cols = self.find_matches( + word_before_cursor, + scoped_cols, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in cols]) + + elif suggestion["type"] == "function": # suggest user-defined functions using substring matching - funcs = self.populate_schema_objects(suggestion['schema'], - 'functions') - user_funcs = self.find_matches(word_before_cursor, funcs) - completions.extend(user_funcs) + funcs = self.populate_schema_objects(suggestion["schema"], "functions") + user_funcs = self.find_matches( + word_before_cursor, + funcs, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in user_funcs]) # suggest hardcoded functions using startswith matching only if # there is no schema qualifier. If a schema qualifier is # present it probably denotes a table. # eg: SELECT * FROM users u WHERE u. - if not suggestion['schema']: - predefined_funcs = self.find_matches(word_before_cursor, - self.functions, - start_only=True, - fuzzy=False, - casing=self.keyword_casing) - completions.extend(predefined_funcs) - - elif suggestion['type'] == 'table': - tables = self.populate_schema_objects(suggestion['schema'], - 'tables') - tables = self.find_matches(word_before_cursor, tables) - completions.extend(tables) - - elif suggestion['type'] == 'view': - views = self.populate_schema_objects(suggestion['schema'], - 'views') - views = self.find_matches(word_before_cursor, views) - completions.extend(views) - - elif suggestion['type'] == 'alias': - aliases = suggestion['aliases'] - aliases = self.find_matches(word_before_cursor, aliases) - completions.extend(aliases) - - elif suggestion['type'] == 'database': - dbs = self.find_matches(word_before_cursor, self.databases) - completions.extend(dbs) - - elif suggestion['type'] == 'keyword': - keywords = self.find_matches(word_before_cursor, self.keywords, - start_only=True, - fuzzy=False, - casing=self.keyword_casing) - completions.extend(keywords) - - elif suggestion['type'] == 'show': - show_items = self.find_matches(word_before_cursor, - self.show_items, - start_only=False, - fuzzy=True, - casing=self.keyword_casing) - completions.extend(show_items) - - elif suggestion['type'] == 'change': - change_items = self.find_matches(word_before_cursor, - self.change_items, - start_only=False, - fuzzy=True) - completions.extend(change_items) - elif suggestion['type'] == 'user': - users = self.find_matches(word_before_cursor, self.users, - start_only=False, - fuzzy=True) - completions.extend(users) - - elif suggestion['type'] == 'special': - special = self.find_matches(word_before_cursor, - self.special_commands, - start_only=True, - fuzzy=False) - completions.extend(special) - elif suggestion['type'] == 'favoritequery': - queries = self.find_matches(word_before_cursor, - FavoriteQueries.instance.list(), - start_only=False, fuzzy=True) - completions.extend(queries) - elif suggestion['type'] == 'table_format': - formats = self.find_matches(word_before_cursor, - self.table_formats, - start_only=True, fuzzy=False) - completions.extend(formats) - elif suggestion['type'] == 'file_name': - file_names = self.find_files(word_before_cursor) - completions.extend(file_names) + if not suggestion["schema"]: + predefined_funcs = self.find_matches( + word_before_cursor, + self.functions, + start_only=True, + fuzzy=False, + casing=self.keyword_casing, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in predefined_funcs]) + + elif suggestion["type"] == "procedure": + procs = self.populate_schema_objects(suggestion["schema"], "procedures") + procs_m = self.find_matches( + word_before_cursor, + procs, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in procs_m]) + + elif suggestion['type'] == 'introducer': + introducers = [f'_{x}' for x in self.character_sets] + introducers_m = self.find_matches( + word_before_cursor, + introducers, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in introducers_m]) + + elif suggestion['type'] == 'character_set': + charsets_m = self.find_matches( + word_before_cursor, + self.character_sets, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in charsets_m]) + + elif suggestion['type'] == 'collation': + collations_m = self.find_matches( + word_before_cursor, + self.collations, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in collations_m]) + + elif suggestion["type"] == "table": + # If this is a select and columns are given, parse the columns and + # then only return tables that have one or more of the given columns. + # If no columns are given (or able to be parsed), return all tables + # as usual. + columns = extract_columns_from_select(document.text) + if columns: + tables = self.populate_schema_objects(suggestion["schema"], "tables", columns) + else: + tables = self.populate_schema_objects(suggestion["schema"], "tables") + + if suggestion.get("join"): + # For JOINs, suggest FK-related tables first (lower rank = higher priority) + current_tables = extract_tables(document.text) + fk_map = self.dbmetadata["foreign_keys"].get(self.dbname, {}).get("tables", {}) + fk_related: set[str] = set() + for tbl_schema, tbl, _alias in current_tables: + # Skip cross-schema tables; FK metadata is only for the current db + if tbl_schema and tbl_schema != self.dbname: + continue + escaped = self.escape_name(tbl) + fk_related.update(fk_map.get(escaped, set())) + fk_tables = [t for t in tables if t in fk_related] + other_tables = [t for t in tables if t not in fk_related] + fk_tables_m = self.find_matches( + word_before_cursor, + fk_tables, + text_before_cursor=document.text_before_cursor, + ) + other_tables_m = self.find_matches( + word_before_cursor, + other_tables, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in fk_tables_m]) + completions.extend([(*x, rank + 1) for x in other_tables_m]) + else: + tables_m = self.find_matches( + word_before_cursor, + tables, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in tables_m]) + + elif suggestion["type"] == "view": + views = self.populate_schema_objects(suggestion["schema"], "views") + views_m = self.find_matches( + word_before_cursor, + views, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in views_m]) + + elif suggestion["type"] == "fk_join": + fk_conditions = self._fk_join_conditions(suggestion["tables"]) + fk_conditions_m = self.find_matches( + word_before_cursor, + fk_conditions, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in fk_conditions_m]) + + elif suggestion["type"] == "alias": + aliases = suggestion["aliases"] + aliases_m = self.find_matches( + word_before_cursor, + aliases, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in aliases_m]) + + elif suggestion["type"] == "database": + dbs_m = self.find_matches( + word_before_cursor, + self.databases, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in dbs_m]) + + elif suggestion["type"] == "keyword": + keywords_m = self.find_matches( + word_before_cursor, + self.keywords, + casing=self.keyword_casing, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in keywords_m]) + + elif suggestion["type"] == "show": + show_items_m = self.find_matches( + word_before_cursor, + self.show_items, + start_only=False, + fuzzy=True, + casing=self.keyword_casing, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in show_items_m]) + + elif suggestion["type"] == "change": + change_items_m = self.find_matches( + word_before_cursor, + self.change_items, + start_only=False, + fuzzy=True, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in change_items_m]) + + elif suggestion["type"] == "user": + users_m = self.find_matches( + word_before_cursor, + self.users, + start_only=False, + fuzzy=True, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in users_m]) + + elif suggestion["type"] == "special": + special_m = self.find_matches( + word_before_cursor, + self.special_commands, + start_only=True, + fuzzy=False, + text_before_cursor=document.text_before_cursor, + ) + # specials are special, and go early in the candidates, first if possible + completions.extend([(*x, 0) for x in special_m]) + + elif suggestion["type"] == "favoritequery": + if hasattr(FavoriteQueries, 'instance') and hasattr(FavoriteQueries.instance, 'list'): + queries_m = self.find_matches( + word_before_cursor, + FavoriteQueries.instance.list(), + start_only=False, + fuzzy=True, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in queries_m]) + + elif suggestion["type"] == "table_format": + formats_m = self.find_matches( + word_before_cursor, + self.table_formats, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in formats_m]) + + elif suggestion["type"] == "file_name": + file_names_m = self.find_files(word_before_cursor) + completions.extend([(*x, rank) for x in file_names_m]) + # for filenames we _really_ want directories to go last + rigid_sort = True + length_based_on_path = True + elif suggestion["type"] == "llm": + if not word_before_cursor: + tokens = document.text.split()[1:] + else: + tokens = document.text.split()[1:-1] + possible_entries = llm.get_completions(tokens) + subcommands_m = self.find_matches( + word_before_cursor, + possible_entries, + start_only=False, + fuzzy=True, + text_before_cursor=document.text_before_cursor, + ) + completions.extend([(*x, rank) for x in subcommands_m]) + + elif suggestion["type"] == "enum_value": + enum_values = self.populate_enum_values( + suggestion["tables"], + suggestion["column"], + suggestion.get("parent"), + ) + if enum_values: + quoted_values = [self._quote_sql_string(value) for value in enum_values] + completions = [ + (*x, rank) + for x in self.find_matches( + word_before_cursor, + quoted_values, + text_before_cursor=document.text_before_cursor, + ) + ] + break + + def completion_sort_key(item: tuple[str, int, int], text_for_len: str): + candidate, fuzziness, rank = item + if not text_for_len: + # sort only by the rank (the order of the completion type) + return (0, rank, 0) + elif candidate.lower().startswith(text_for_len): + # sort only by the length of the candidate + return (0, 0, -1000 + len(candidate)) + # sort by fuzziness and rank + # todo add alpha here, or original order? + return (fuzziness, rank, 0) + + if rigid_sort: + uniq_completions_str = dict.fromkeys(x[0] for x in completions) + else: + sorted_completions = sorted(completions, key=lambda item: completion_sort_key(item, text_for_len.lower())) + uniq_completions_str = dict.fromkeys(x[0] for x in sorted_completions) - return completions + if length_based_on_path: + return (Completion(x, -len(last_for_len_paths)) for x in uniq_completions_str) + else: + return (Completion(x, -len(text_for_len)) for x in uniq_completions_str) - def find_files(self, word): + def find_files(self, word: str) -> Generator[tuple[str, int], None, None]: """Yield matching directory or file names. :param word: :return: iterable """ + # todo position is ignored, but may need to be used + # todo fuzzy matches for filenames base_path, last_path, position = parse_path(word) paths = suggest_path(word) - for name in sorted(paths): + for name in paths: suggestion = complete_path(name, last_path) if suggestion: - yield Completion(suggestion, position) + yield (suggestion, Fuzziness.PERFECT) - def populate_scoped_cols(self, scoped_tbls): + def populate_scoped_cols(self, scoped_tbls: list[tuple[str | None, str, str | None]]) -> list[str]: """Find all columns in a set of scoped_tables :param scoped_tbls: list of (schema, table, alias) tuples :return: list of column names @@ -544,6 +1751,14 @@ def populate_scoped_cols(self, scoped_tbls): columns = [] meta = self.dbmetadata + # if scoped tables is empty, this is just after a SELECT so we + # show all columns for all tables in the schema. + if len(scoped_tbls) == 0 and self.dbname: + for table in meta["tables"][self.dbname]: + columns.extend(meta["tables"][self.dbname][table]) + return columns or ['*'] + + # query includes tables, so use those to populate columns for tbl in scoped_tbls: # A fully qualified schema.relname reference or default_schema # DO NOT escape schema names. @@ -555,34 +1770,102 @@ def populate_scoped_cols(self, scoped_tbls): # tables and views cannot share the same name, we can check one # at a time try: - columns.extend(meta['tables'][schema][relname]) + columns.extend(meta["tables"][schema][relname]) # Table exists, so don't bother checking for a view continue except KeyError: try: - columns.extend(meta['tables'][schema][escaped_relname]) + columns.extend(meta["tables"][schema][escaped_relname]) # Table exists, so don't bother checking for a view continue except KeyError: pass try: - columns.extend(meta['views'][schema][relname]) + columns.extend(meta["views"][schema][relname]) except KeyError: pass return columns - def populate_schema_objects(self, schema, obj_type): + def populate_enum_values( + self, + scoped_tbls: list[tuple[str | None, str, str | None]], + column: str, + parent: str | None = None, + ) -> list[str]: + values: list[str] = [] + meta = self.dbmetadata["enum_values"] + column_key = self._escape_identifier(column) + parent_key = self._strip_backticks(parent) if parent else None + + for schema, relname, alias in scoped_tbls: + if parent_key and not self._matches_parent(parent_key, schema, relname, alias): + continue + + schema = schema or self.dbname + table_meta = meta.get(schema, {}) + escaped_relname = self.escape_name(relname) + + for rel_key in {relname, escaped_relname}: + columns = table_meta.get(rel_key) + if columns and column_key in columns: + values.extend(columns[column_key]) + + return list(dict.fromkeys(values)) + + def _escape_identifier(self, name: str) -> str: + return self.escape_name(self._strip_backticks(name)) + + @staticmethod + def _strip_backticks(name: str | None) -> str: + if name and name[0] == "`" and name[-1] == "`": + return name[1:-1] + return name or "" + + @staticmethod + def _matches_parent(parent: str, schema: str | None, relname: str, alias: str | None) -> bool: + if alias and parent == alias: + return True + if parent == relname: + return True + if schema and parent == f"{schema}.{relname}": + return True + return False + + @staticmethod + def _quote_sql_string(value: str) -> str: + return "'" + value.replace("'", "''") + "'" + + def populate_schema_objects(self, schema: str | None, obj_type: str, columns: list[str] | None = None) -> list[str]: """Returns list of tables or functions for a (optional) schema""" metadata = self.dbmetadata[obj_type] schema = schema or self.dbname - try: - objects = metadata[schema].keys() + objects = list(metadata[schema].keys()) except KeyError: # schema doesn't exist objects = [] - return objects + filtered_objects: list[str] = [] + remaining_objects: list[str] = [] + + # If the requested object type is tables and the user already entered + # columns, return a filtered list of tables (or views) that contain + # one or more of the given columns. If a table does not contain the + # given columns, add it to a separate list to add to the end of the + # filtered suggestions. + if obj_type == "tables" and columns and objects: + for obj in objects: + matched = False + for column in metadata[schema][obj]: + if column in columns: + filtered_objects.append(obj) + matched = True + break + if not matched: + remaining_objects.append(obj) + else: + filtered_objects = objects + return filtered_objects + remaining_objects diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index bd5f5d98..ecf975ff 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,170 +1,276 @@ +from __future__ import annotations + +import datetime import enum import logging import re +import ssl +from typing import Any, Generator, Iterable +from prompt_toolkit.formatted_text import FormattedText import pymysql -from .packages import special +from pymysql.connections import Connection from pymysql.constants import FIELD_TYPE -from pymysql.converters import (convert_datetime, - convert_timedelta, convert_date, conversions, - decoders) +from pymysql.converters import conversions, convert_date, convert_datetime, convert_time, decoders +from pymysql.cursors import Cursor + +from mycli.constants import ER_MUST_CHANGE_PASSWORD +from mycli.packages.special import iocommands +from mycli.packages.special.main import CommandNotFound, execute +from mycli.packages.sqlresult import SQLResult + try: - import paramiko + import paramiko # noqa: F401 + import sshtunnel except ImportError: - from mycli.packages.paramiko_stub import paramiko + pass _logger = logging.getLogger(__name__) FIELD_TYPES = decoders.copy() -FIELD_TYPES.update({ - FIELD_TYPE.NULL: type(None) -}) +FIELD_TYPES.update({FIELD_TYPE.NULL: type(None)}) ERROR_CODE_ACCESS_DENIED = 1045 class ServerSpecies(enum.Enum): - MySQL = 'MySQL' - MariaDB = 'MariaDB' - Percona = 'Percona' - TiDB = 'TiDB' - Unknown = 'MySQL' + MySQL = "MySQL" + MariaDB = "MariaDB" + Percona = "Percona" + TiDB = "TiDB" + Unknown = "Unknown" class ServerInfo: - def __init__(self, species, version_str): + def __init__(self, species: ServerSpecies | None, version_str: str) -> None: self.species = species self.version_str = version_str self.version = self.calc_mysql_version_value(version_str) @staticmethod - def calc_mysql_version_value(version_str) -> int: + def calc_mysql_version_value(version_str: str) -> int: if not version_str or not isinstance(version_str, str): return 0 try: - major, minor, patch = version_str.split('.') + major, minor, patch = version_str.split(".") except ValueError: return 0 else: return int(major) * 10_000 + int(minor) * 100 + int(patch) @classmethod - def from_version_string(cls, version_string): + def from_version_string(cls, version_string: str) -> ServerInfo: if not version_string: - return cls(ServerSpecies.Unknown, '') + return cls(ServerSpecies.MySQL, "") re_species = ( - (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), - (r'[0-9\.]*-TiDB-v(?P[0-9\.]+)-?(?P[a-z0-9\-]*)', ServerSpecies.TiDB), - (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', - ServerSpecies.Percona), - (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', - ServerSpecies.MySQL), + (r"(?P[0-9\.]+)-MariaDB", ServerSpecies.MariaDB), + (r"[0-9\.]*-TiDB-v(?P[0-9\.]+)-?(?P[a-z0-9\-]*)", ServerSpecies.TiDB), + (r"(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)", ServerSpecies.Percona), + (r"(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)", ServerSpecies.MySQL), ) for regexp, species in re_species: match = re.search(regexp, version_string) if match is not None: - parsed_version = match.group('version') + parsed_version = match.group("version") detected_species = species break else: - detected_species = ServerSpecies.Unknown - parsed_version = '' + detected_species = ServerSpecies.MySQL + parsed_version = "" return cls(detected_species, parsed_version) - def __str__(self): + def __str__(self) -> str: if self.species: - return f'{self.species.value} {self.version_str}' + return f"{self.species.value} {self.version_str}" else: return self.version_str -class SQLExecute(object): - - databases_query = '''SHOW DATABASES''' +class SQLExecute: + databases_query = """SHOW DATABASES""" - tables_query = '''SHOW TABLES''' + tables_query = """SHOW TABLES""" show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' - users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' + users_query = """SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user""" functions_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES - WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' + WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = %s''' + + procedures_query = '''SELECT ROUTINE_NAME FROM INFORMATION_SCHEMA.ROUTINES + WHERE ROUTINE_TYPE="PROCEDURE" AND ROUTINE_SCHEMA = %s''' + + character_sets_query = '''SHOW CHARACTER SET''' + + collations_query = '''SHOW COLLATION''' - table_columns_query = '''select TABLE_NAME, COLUMN_NAME from information_schema.columns - where table_schema = '%s' - order by table_name,ordinal_position''' + table_columns_query = """select TABLE_NAME, COLUMN_NAME from information_schema.columns + where table_schema = %s + order by table_name,ordinal_position""" - def __init__(self, database, user, password, host, port, socket, charset, - local_infile, ssl, ssh_user, ssh_host, ssh_port, ssh_password, - ssh_key_filename, init_command=None): + enum_values_query = """select TABLE_NAME, COLUMN_NAME, COLUMN_TYPE from information_schema.columns + where table_schema = %s and data_type = 'enum' + order by table_name,ordinal_position""" + + foreign_keys_query = """SELECT TABLE_NAME, COLUMN_NAME, REFERENCED_TABLE_NAME, REFERENCED_COLUMN_NAME + FROM information_schema.KEY_COLUMN_USAGE + WHERE TABLE_SCHEMA = %s AND REFERENCED_TABLE_NAME IS NOT NULL""" + + now_query = """SELECT NOW()""" + + @staticmethod + def _parse_enum_values(column_type: str) -> list[str]: + if not column_type or not column_type.lower().startswith("enum("): + return [] + + values: list[str] = [] + current: list[str] = [] + in_quote = False + i = column_type.find("(") + 1 + + while i < len(column_type): + ch = column_type[i] + + if not in_quote: + if ch == "'": + in_quote = True + current = [] + elif ch == ")": + break + else: + if ch == "\\" and i + 1 < len(column_type): + current.append(column_type[i + 1]) + i += 1 + elif ch == "'": + if i + 1 < len(column_type) and column_type[i + 1] == "'": + current.append("'") + i += 1 + else: + values.append("".join(current)) + in_quote = False + else: + current.append(ch) + i += 1 + + return values + + def __init__( + self, + database: str | None, + user: str | None, + password: str | None, + host: str | None, + port: int | None, + socket: str | None, + character_set: str | None, + local_infile: bool | None, + ssl: dict[str, Any] | None, + ssh_user: str | None, + ssh_host: str | None, + ssh_port: int | None, + ssh_password: str | None, + ssh_key_filename: str | None, + init_command: str | None = None, + unbuffered: bool | None = None, + ) -> None: self.dbname = database self.user = user self.password = password self.host = host self.port = port self.socket = socket - self.charset = charset + self.character_set = character_set self.local_infile = local_infile self.ssl = ssl - self.server_info = None - self.connection_id = None + self.server_info: ServerInfo | None = None + self.connection_id: int | None = None self.ssh_user = ssh_user self.ssh_host = ssh_host self.ssh_port = ssh_port self.ssh_password = ssh_password self.ssh_key_filename = ssh_key_filename self.init_command = init_command + self.unbuffered = unbuffered + self.conn: Connection | None = None self.connect() - def connect(self, database=None, user=None, password=None, host=None, - port=None, socket=None, charset=None, local_infile=None, - ssl=None, ssh_host=None, ssh_port=None, ssh_user=None, - ssh_password=None, ssh_key_filename=None, init_command=None): - db = (database or self.dbname) - user = (user or self.user) - password = (password or self.password) - host = (host or self.host) - port = (port or self.port) - socket = (socket or self.socket) - charset = (charset or self.charset) - local_infile = (local_infile or self.local_infile) - ssl = (ssl or self.ssl) - ssh_user = (ssh_user or self.ssh_user) - ssh_host = (ssh_host or self.ssh_host) - ssh_port = (ssh_port or self.ssh_port) - ssh_password = (ssh_password or self.ssh_password) - ssh_key_filename = (ssh_key_filename or self.ssh_key_filename) - init_command = (init_command or self.init_command) + def connect( + self, + database: str | None = None, + user: str | None = None, + password: str | None = None, + host: str | None = None, + port: int | None = None, + socket: str | None = None, + character_set: str | None = None, + local_infile: bool | None = None, + ssl: dict[str, Any] | None = None, + ssh_host: str | None = None, + ssh_port: int | None = None, + ssh_user: str | None = None, + ssh_password: str | None = None, + ssh_key_filename: str | None = None, + init_command: str | None = None, + unbuffered: bool | None = None, + ): + db = database if database is not None else self.dbname + user = user if user is not None else self.user + password = password if password is not None else self.password + host = host if host is not None else self.host + port = port if port is not None else self.port + socket = socket if socket is not None else self.socket + character_set = character_set if character_set is not None else self.character_set + local_infile = local_infile if local_infile is not None else self.local_infile + ssl = ssl if ssl is not None else self.ssl + ssh_user = ssh_user if ssh_user is not None else self.ssh_user + ssh_host = ssh_host if ssh_host is not None else self.ssh_host + ssh_port = ssh_port if ssh_port is not None else self.ssh_port + ssh_password = ssh_password if ssh_password is not None else self.ssh_password + ssh_key_filename = ssh_key_filename if ssh_key_filename is not None else self.ssh_key_filename + init_command = init_command if init_command is not None else self.init_command + unbuffered = unbuffered if unbuffered is not None else self.unbuffered _logger.debug( - 'Connection DB Params: \n' - '\tdatabase: %r' - '\tuser: %r' - '\thost: %r' - '\tport: %r' - '\tsocket: %r' - '\tcharset: %r' - '\tlocal_infile: %r' - '\tssl: %r' - '\tssh_user: %r' - '\tssh_host: %r' - '\tssh_port: %r' - '\tssh_password: %r' - '\tssh_key_filename: %r' - '\tinit_command: %r', - db, user, host, port, socket, charset, local_infile, ssl, - ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, - init_command + "Connection DB Params: \n" + "\tdatabase: %r" + "\tuser: %r" + "\thost: %r" + "\tport: %r" + "\tsocket: %r" + "\tcharacter_set: %r" + "\tlocal_infile: %r" + "\tssl: %r" + "\tssh_user: %r" + "\tssh_host: %r" + "\tssh_port: %r" + "\tssh_password: ***" + "\tssh_key_filename: %r" + "\tinit_command: %r" + "\tunbuffered: %r", + db, + user, + host, + port, + socket, + character_set, + local_infile, + ssl, + ssh_user, + ssh_host, + ssh_port, + ssh_key_filename, + init_command, + unbuffered, ) conv = conversions.copy() conv.update({ - FIELD_TYPE.TIMESTAMP: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.DATETIME: lambda obj: (convert_datetime(obj) or obj), - FIELD_TYPE.TIME: lambda obj: (convert_timedelta(obj) or obj), - FIELD_TYPE.DATE: lambda obj: (convert_date(obj) or obj), + FIELD_TYPE.TIMESTAMP: lambda obj: convert_datetime(obj) or obj, + FIELD_TYPE.DATETIME: lambda obj: convert_datetime(obj) or obj, + FIELD_TYPE.TIME: lambda obj: convert_time(obj) or obj, + FIELD_TYPE.DATE: lambda obj: convert_date(obj) or obj, }) defer_connect = False @@ -173,38 +279,76 @@ def connect(self, database=None, user=None, password=None, host=None, defer_connect = True client_flag = pymysql.constants.CLIENT.INTERACTIVE - if init_command and len(list(special.split_queries(init_command))) > 1: + if init_command and len(list(iocommands.split_queries(init_command))) > 1: client_flag |= pymysql.constants.CLIENT.MULTI_STATEMENTS + client_flag |= pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS ssl_context = None if ssl: ssl_context = self._create_ssl_ctx(ssl) - conn = pymysql.connect( - database=db, user=user, password=password, host=host, port=port, - unix_socket=socket, use_unicode=True, charset=charset, - autocommit=True, client_flag=client_flag, - local_infile=local_infile, conv=conv, ssl=ssl_context, program_name="mycli", - defer_connect=defer_connect, init_command=init_command - ) + connect_kwargs: dict[str, Any] = { + "database": db, + "user": user, + "password": password or '', + "host": host, + "port": port or 0, + "unix_socket": socket, + "use_unicode": True, + "charset": character_set or '', + "autocommit": True, + "client_flag": client_flag, + "local_infile": local_infile or False, + "conv": conv, + "ssl": ssl_context, # type: ignore[arg-type] + "program_name": "mycli", + "defer_connect": defer_connect, + "init_command": init_command or None, + "cursorclass": pymysql.cursors.SSCursor if unbuffered else pymysql.cursors.Cursor, + } + + self.sandbox_mode = False + try: + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + except pymysql.OperationalError as e: + if e.args[0] == ER_MUST_CHANGE_PASSWORD: + # Post-handshake queries (SET NAMES, SET AUTOCOMMIT, init_command) + # fail with ER_MUST_CHANGE_PASSWORD in sandbox mode. + # Reconnect with only the raw handshake. + connect_kwargs['defer_connect'] = True + connect_kwargs['autocommit'] = None + connect_kwargs['init_command'] = None + conn = pymysql.connect(**connect_kwargs) # type: ignore[misc] + self._connect_sandbox(conn) + self.sandbox_mode = True + else: + raise - if ssh_host: - client = paramiko.SSHClient() - client.load_system_host_keys() - client.set_missing_host_key_policy(paramiko.WarningPolicy()) - client.connect( - ssh_host, ssh_port, ssh_user, ssh_password, - key_filename=ssh_key_filename - ) - chan = client.get_transport().open_channel( - 'direct-tcpip', - (host, port), - ('0.0.0.0', 0), - ) - conn.connect(chan) - - if hasattr(self, 'conn'): - self.conn.close() + if ssh_host and not self.sandbox_mode: + ##### paramiko.Channel is a bad socket implementation overall if you want SSL through an SSH tunnel + ##### + # instead let's open a tunnel and rewrite host:port to local bind + try: + chan = sshtunnel.SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_pkey=ssh_key_filename, + ssh_password=ssh_password, + remote_bind_address=(host, port), + ) + chan.start() + + conn.host = chan.local_bind_host + conn.port = chan.local_bind_port + conn.connect() + except Exception as e: + raise e + + if self.conn is not None: + try: + self.conn.close() + except pymysql.err.Error: + pass self.conn = conn # Update them after the connection is made to ensure that it was a # successful connection. @@ -214,45 +358,50 @@ def connect(self, database=None, user=None, password=None, host=None, self.host = host self.port = port self.socket = socket - self.charset = charset + self.character_set = character_set self.ssl = ssl self.init_command = init_command - # retrieve connection id - self.reset_connection_id() - self.server_info = ServerInfo.from_version_string(conn.server_version) - - def run(self, statement): - """Execute the sql in the database and return the results. The results - are a list of tuples. Each tuple has 4 values - (title, rows, headers, status). - """ + self.unbuffered = unbuffered + # retrieve connection id (skip in sandbox mode as queries will fail) + if not self.sandbox_mode: + self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) # type: ignore[attr-defined] + + def run(self, statement: str) -> Generator[SQLResult, None, None]: + """Execute the sql in the database and return the results.""" # Remove spaces and EOL statement = statement.strip() if not statement: # Empty string - yield (None, None, None, None) + yield SQLResult() # Split the sql into separate queries and run each one. # Unless it's saving a favorite query, in which case we # want to save them all together. - if statement.startswith('\\fs'): - components = [statement] + if statement.startswith("\\fs"): + components: Iterable[str] = [statement] else: - components = special.split_queries(statement) + components = iocommands.split_queries(statement) for sql in components: # \G is treated specially since we have to set the expanded output. - if sql.endswith('\\G'): - special.set_expanded_output(True) + if sql.endswith("\\G"): + iocommands.set_expanded_output(True) + sql = sql[:-2].strip() + # \g is treated specially since we might want collapsed output when + # auto vertical output is enabled + elif sql.endswith('\\g'): + iocommands.set_expanded_output(False) + iocommands.set_forced_horizontal_output(True) sql = sql[:-2].strip() + assert isinstance(self.conn, Connection) cur = self.conn.cursor() - try: # Special command - _logger.debug('Trying a dbspecial command. sql: %r', sql) - for result in special.execute(cur, sql): - yield result - except special.CommandNotFound: # Regular SQL - _logger.debug('Regular sql statement. sql: %r', sql) + try: # Special command + _logger.debug("Trying a dbspecial command. sql: %r", sql) + yield from execute(cur, sql) + except CommandNotFound: # Regular SQL + _logger.debug("Regular sql statement. sql: %r", sql) cur.execute(sql) while True: yield self.get_result(cur) @@ -263,105 +412,214 @@ def run(self, statement): if not cur.nextset() or (not cur.rowcount and cur.description is None): break - def get_result(self, cursor): + def get_result(self, cursor: Cursor) -> SQLResult: """Get the current result's data from the cursor.""" - title = headers = None + preamble = header = None # cursor.description is not None for queries that return result sets, # e.g. SELECT or SHOW. - if cursor.description is not None: - headers = [x[0] for x in cursor.description] - status = '{0} row{1} in set' + plural = '' if cursor.rowcount == 1 else 's' + if cursor.description: + header = [x[0] for x in cursor.description] + status = FormattedText([('', f'{cursor.rowcount} row{plural} in set')]) else: - _logger.debug('No rows in result.') - status = 'Query OK, {0} row{1} affected' - status = status.format(cursor.rowcount, - '' if cursor.rowcount == 1 else 's') + _logger.debug("No rows in result.") + status = FormattedText([('', f'Query OK, {cursor.rowcount} row{plural} affected')]) + + if cursor.warning_count > 0: + plural = '' if cursor.warning_count == 1 else 's' + comma = FormattedText([('', ', ')]) + warning_count = FormattedText([('class:output.status.warning-count', f'{cursor.warning_count} warning{plural}')]) + status.extend(comma) + status.extend(warning_count) - return (title, cursor if cursor.description else None, headers, status) + return SQLResult(preamble=preamble, header=header, rows=cursor, status=status) - def tables(self): + def tables(self) -> Generator[tuple[str], None, None]: """Yields table names""" + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Tables Query. sql: %r', self.tables_query) + _logger.debug("Tables Query. sql: %r", self.tables_query) cur.execute(self.tables_query) - for row in cur: - yield row + yield from cur - def table_columns(self): - """Yields (table name, column name) pairs""" + def table_columns(self, schema: str | None = None) -> Generator[tuple[str, str], None, None]: + """Yields (table name, column name) pairs for *schema* (default: current database).""" + target = schema if schema is not None else self.dbname + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Columns Query. sql: %r schema: %r", self.table_columns_query, target) + cur.execute(self.table_columns_query, (target,)) + yield from cur + + def enum_values(self, schema: str | None = None) -> Generator[tuple[str, str, list[str]], None, None]: + """Yields (table name, column name, enum values) tuples for *schema*.""" + target = schema if schema is not None else self.dbname + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Columns Query. sql: %r', self.table_columns_query) - cur.execute(self.table_columns_query % self.dbname) - for row in cur: - yield row + _logger.debug("Enum Values Query. sql: %r schema: %r", self.enum_values_query, target) + cur.execute(self.enum_values_query, (target,)) + for table_name, column_name, column_type in cur: + values = self._parse_enum_values(column_type) + if values: + yield (table_name, column_name, values) + + def foreign_keys(self, schema: str | None = None) -> Generator[tuple[str, str, str, str], None, None]: + """Yields (table_name, column_name, referenced_table_name, referenced_column_name) tuples for *schema*.""" + target = schema if schema is not None else self.dbname + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Foreign Keys Query. sql: %r schema: %r", self.foreign_keys_query, target) + try: + cur.execute(self.foreign_keys_query, (target,)) + yield from cur + except Exception as e: + _logger.error('No foreign key completions due to %r', e) - def databases(self): + def databases(self) -> list[str]: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Databases Query. sql: %r', self.databases_query) + _logger.debug("Databases Query. sql: %r", self.databases_query) cur.execute(self.databases_query) return [x[0] for x in cur.fetchall()] - def functions(self): - """Yields tuples of (schema_name, function_name)""" + def functions(self, schema: str | None = None) -> Generator[tuple[str, str], None, None]: + """Yields tuples of (schema_name, function_name) for *schema*.""" + + target = schema if schema is not None else self.dbname + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Functions Query. sql: %r schema: %r", self.functions_query, target) + cur.execute(self.functions_query, (target,)) + yield from cur + + def procedures(self, schema: str | None = None) -> Generator[tuple, None, None]: + """Yields tuples of (procedure_name, ) for *schema*.""" + + target = schema if schema is not None else self.dbname + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Procedures Query. sql: %r schema: %r", self.procedures_query, target) + try: + cur.execute(self.procedures_query, (target,)) + except pymysql.DatabaseError as e: + _logger.error('No procedure completions due to %r', e) + yield () + else: + yield from cur + def character_sets(self) -> Generator[tuple, None, None]: + """Yields tuples of (character_set_name, )""" + + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Functions Query. sql: %r', self.functions_query) - cur.execute(self.functions_query % self.dbname) - for row in cur: - yield row + _logger.debug("Character sets Query. sql: %r", self.character_sets_query) + try: + cur.execute(self.character_sets_query) + except pymysql.DatabaseError as e: + _logger.error('No character_set completions due to %r', e) + yield () + else: + yield from cur - def show_candidates(self): + def collations(self) -> Generator[tuple, None, None]: + """Yields tuples of (collation_name, )""" + + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Show Query. sql: %r', self.show_candidates_query) + _logger.debug("Collations Query. sql: %r", self.collations_query) + try: + cur.execute(self.collations_query) + except pymysql.DatabaseError as e: + _logger.error('No collations completions due to %r', e) + yield () + else: + yield from cur + + def show_candidates(self) -> Generator[tuple, None, None]: + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Show Query. sql: %r", self.show_candidates_query) try: cur.execute(self.show_candidates_query) except pymysql.DatabaseError as e: - _logger.error('No show completions due to %r', e) - yield '' + _logger.error("No show completions due to %r", e) + yield () else: for row in cur: - yield (row[0].split(None, 1)[-1], ) + yield (row[0].split(None, 1)[-1],) - def users(self): + def users(self) -> Generator[tuple, None, None]: + assert isinstance(self.conn, Connection) with self.conn.cursor() as cur: - _logger.debug('Users Query. sql: %r', self.users_query) + _logger.debug("Users Query. sql: %r", self.users_query) try: cur.execute(self.users_query) except pymysql.DatabaseError as e: - _logger.error('No user completions due to %r', e) - yield '' + _logger.error("No user completions due to %r", e) + yield () else: - for row in cur: - yield row + yield from cur + + def now(self) -> datetime.datetime: + assert isinstance(self.conn, Connection) + with self.conn.cursor() as cur: + _logger.debug("Now Query. sql: %r", self.now_query) + cur.execute(self.now_query) + if one := cur.fetchone(): + return one[0] + else: + return datetime.datetime.now() - def get_connection_id(self): + def get_connection_id(self) -> int | None: if not self.connection_id: self.reset_connection_id() return self.connection_id - def reset_connection_id(self): + def reset_connection_id(self) -> None: # Remember current connection id - _logger.debug('Get current connection id') + _logger.debug("Get current connection id") try: - res = self.run('select connection_id()') - for title, cur, headers, status in res: - self.connection_id = cur.fetchone()[0] + results = self.run("select connection_id()") + for result in results: + cur = result.rows + if isinstance(cur, Cursor): + v = cur.fetchone() + self.connection_id = v[0] if v is not None else -1 + else: + raise ValueError except Exception as e: # See #1054 self.connection_id = -1 - _logger.error('Failed to get connection id: %s', e) + _logger.error("Failed to get connection id: %s", e) else: - _logger.debug('Current connection id: %s', self.connection_id) + _logger.debug("Current connection id: %s", self.connection_id) - def change_db(self, db): + def change_db(self, db: str) -> None: + assert isinstance(self.conn, Connection) self.conn.select_db(db) self.dbname = db - def _create_ssl_ctx(self, sslp): - import ssl + @staticmethod + def _connect_sandbox(conn: Connection) -> None: + """Connect in sandbox mode, performing only the handshake. + + pymysql's normal connect() runs post-handshake queries (SET NAMES, + SET AUTOCOMMIT, init_command) that all fail with ER_MUST_CHANGE_PASSWORD + in sandbox mode. This method performs the raw socket connection and + authentication handshake only. + """ + # Reuse pymysql internals for the handshake + auth, but + # temporarily stub out set_character_set so it becomes a no-op. + original_set_charset = conn.set_character_set + conn.set_character_set = lambda *_args, **_kwargs: None # type: ignore[assignment] + try: + conn.connect() + finally: + conn.set_character_set = original_set_charset # type: ignore[assignment] + def _create_ssl_ctx(self, sslp: dict) -> ssl.SSLContext: ca = sslp.get("ca") capath = sslp.get("capath") hasnoca = ca is None and capath is None @@ -373,8 +631,7 @@ def _create_ssl_ctx(self, sslp): if "cipher" in sslp: ctx.set_ciphers(sslp["cipher"]) - # raise this default to v1.1 or v1.2? - ctx.minimum_version = ssl.TLSVersion.TLSv1 + ctx.minimum_version = ssl.TLSVersion.TLSv1_2 if "tls_version" in sslp: tls_version = sslp["tls_version"] @@ -392,6 +649,13 @@ def _create_ssl_ctx(self, sslp): ctx.minimum_version = ssl.TLSVersion.TLSv1_3 ctx.maximum_version = ssl.TLSVersion.TLSv1_3 else: - _logger.error('Invalid tls version: %s', tls_version) + _logger.error("Invalid tls version: %s", tls_version) return ctx + + def close(self) -> None: + if self.conn is not None: + try: + self.conn.close() + except pymysql.err.Error: + pass diff --git a/mycli/types.py b/mycli/types.py new file mode 100644 index 00000000..207d62d9 --- /dev/null +++ b/mycli/types.py @@ -0,0 +1,4 @@ +from collections import namedtuple + +# Query tuples are used for maintaining history +Query = namedtuple("Query", ["query", "successful", "mutating"]) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..060cad21 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,160 @@ +[project] +name = "mycli" +dynamic = ["version"] +description = "CLI for MySQL Database. With auto-completion and syntax highlighting." +readme = "README.md" +requires-python = ">=3.10" +license = "BSD-3-Clause" +authors = [{ name = "Mycli Core Team" }] + +dependencies = [ + "click ~= 8.3.1", + "clickdc ~= 0.1.1", + "cryptography ~= 46.0.5", + "Pygments ~= 2.19.2", + "prompt_toolkit>=3.0.41,<4.0.0", + "PyMySQL ~= 1.1.2", + "sqlparse>=0.3.0,<0.6.0", + "sqlglot[c] ~= 30.7.0", + "configobj ~= 5.0.9", + "cli_helpers[styles] ~= 2.15.0", + "wcwidth ~= 0.6.0", + "pyperclip ~= 1.11.0", + "pycryptodomex ~= 3.23.0", + "pyfzf ~= 0.3.1", + "rapidfuzz ~= 3.14.3", + "keyring ~= 25.7.0", +] + +[project.urls] +Homepage = 'https://mycli.net' +Documentation = 'https://mycli.net/docs' +Source = 'https://github.com/dbcli/mycli' +Issues = 'https://github.com/dbcli/mycli/issues' +Changelog = 'https://github.com/dbcli/mycli/blob/main/changelog.md' + +[build-system] +requires = ["setuptools>=64.0", "setuptools-scm>=8"] +build-backend = "setuptools.build_meta" + +[tool.setuptools_scm] + + +[project.optional-dependencies] +ssh = [ + "paramiko ~= 3.5.1", + "sshtunnel ~= 0.4.0", +] +llm = [ + "llm ~= 0.30.0", + "pydantic_core ~= 2.41.5", # Required by llm; force a newer version + "setuptools == 82.*", # Required by llm commands to install models + "pip == 26.*", +] +all = [ + "mycli[ssh]", + "mycli[llm]", +] +dev = [ + "behave ~= 1.3.3", + "coverage ~= 7.13.4", + "mypy ~= 1.19.1", + "pexpect ~= 4.9.0", + "pytest ~= 9.0.2", + "pytest-cov ~= 7.0.0", + "pytest-random-order ~= 1.2.0", + "tox ~= 4.35.0", + "pdbpp ~= 0.11.7", + "paramiko ~= 3.5.1", + "sshtunnel ~= 0.4.0", + "llm ~= 0.30.0", + "pydantic_core ~= 2.41.5", # Required by llm; force a newer version + "setuptools == 82.*", # Required by llm commands to install models + "pip == 26.*", + "ruff ~= 0.15.0", +] + +[project.scripts] +mycli = "mycli.main:main" + +[tool.setuptools.package-data] +mycli = ["myclirc", "AUTHORS", "SPONSORS", "TIPS"] + +[tool.setuptools.packages.find] +include = ["mycli*"] + +[tool.ruff] +target-version = 'py310' +line-length = 140 + +[tool.ruff.lint] +select = ['A', 'B', 'I', 'E', 'W', 'F', 'C4', 'PIE', 'TID'] +ignore = [ + 'B005', # Multi-character strip() + 'E401', # Multiple imports on one line + 'E402', # Module level import not at top of file + 'PIE808', # range() starting with 0 + # https://docs.astral.sh/ruff/formatter/#conflicting-lint-rules + 'E111', # indentation-with-invalid-multiple + 'E114', # indentation-with-invalid-multiple-comment + 'E117', # over-indented + 'W191', # tab-indentation +] + +[tool.ruff.lint.isort] +force-sort-within-sections = true +known-first-party = ['mycli', 'test', 'steps'] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = 'all' + +[tool.ruff.format] +preview = true +quote-style = 'preserve' +exclude = ['build', 'mycli_dev'] + +[tool.mypy] +pretty = true +strict_equality = true +ignore_missing_imports = true +warn_unreachable = true +warn_redundant_casts = true +warn_no_return = true +warn_unused_configs = true +show_column_numbers = true +exclude = ['^build/', '^dist/'] + +[tool.tox] +env_list = ['python'] +requires = ['tox>=4.20'] + +[tool.tox.env_run_base] +skip_install = true +deps = ['uv'] +passenv = ['PYTEST_HOST', + 'PYTEST_USER', + 'PYTEST_PASSWORD', + 'PYTEST_PORT', + 'PYTEST_CHARSET'] +commands = [['uv', 'pip', 'install', '-e', '.[dev,ssh,llm]'], + ['coverage', 'run', '-m', 'pytest', '-v', 'test'], + ['coverage', 'report', '-m', '--sort=Miss'], + ['behave', 'test/features']] +commands_post = [['rm', '-f', '--', './.myclirc']] +allowlist_externals = ['rm'] + +[tool.tox.env.style] +skip_install = true +deps = ['ruff'] +commands = [['ruff', 'check'], + ['ruff', 'format', '--diff']] + +[tool.pytest] +addopts = ['--ignore=mycli/packages/paramiko_stub/__init__.py', '--random-order'] + +[tool.coverage.run] +source = ['mycli'] +omit = [ + # deprecated + 'mycli/packages/paramiko_stub/__init__.py', +] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 5422131c..00000000 --- a/pytest.ini +++ /dev/null @@ -1,2 +0,0 @@ -[pytest] -addopts = --ignore=mycli/packages/paramiko_stub/__init__.py diff --git a/release.py b/release.py deleted file mode 100755 index 62daa802..00000000 --- a/release.py +++ /dev/null @@ -1,119 +0,0 @@ -"""A script to publish a release of mycli to PyPI.""" - -from optparse import OptionParser -import re -import subprocess -import sys - -import click - -DEBUG = False -CONFIRM_STEPS = False -DRY_RUN = False - - -def skip_step(): - """ - Asks for user's response whether to run a step. Default is yes. - :return: boolean - """ - global CONFIRM_STEPS - - if CONFIRM_STEPS: - return not click.confirm('--- Run this step?', default=True) - return False - - -def run_step(*args): - """ - Prints out the command and asks if it should be run. - If yes (default), runs it. - :param args: list of strings (command and args) - """ - global DRY_RUN - - cmd = args - print(' '.join(cmd)) - if skip_step(): - print('--- Skipping...') - elif DRY_RUN: - print('--- Pretending to run...') - else: - subprocess.check_output(cmd) - - -def version(version_file): - _version_re = re.compile( - r'__version__\s+=\s+(?P[\'"])(?P.*)(?P=quote)') - - with open(version_file) as f: - ver = _version_re.search(f.read()).group('version') - - return ver - - -def commit_for_release(version_file, ver): - run_step('git', 'reset') - run_step('git', 'add', version_file) - run_step('git', 'commit', '--message', - 'Releasing version {}'.format(ver)) - - -def create_git_tag(tag_name): - run_step('git', 'tag', tag_name) - - -def create_distribution_files(): - run_step('python', 'setup.py', 'sdist', 'bdist_wheel') - - -def upload_distribution_files(): - run_step('twine', 'upload', 'dist/*') - - -def push_to_github(): - run_step('git', 'push', 'origin', 'main') - - -def push_tags_to_github(): - run_step('git', 'push', '--tags', 'origin') - - -def checklist(questions): - for question in questions: - if not click.confirm('--- {}'.format(question), default=False): - sys.exit(1) - - -if __name__ == '__main__': - if DEBUG: - subprocess.check_output = lambda x: x - - ver = version('mycli/__init__.py') - - parser = OptionParser() - parser.add_option( - "-c", "--confirm-steps", action="store_true", dest="confirm_steps", - default=False, help=("Confirm every step. If the step is not " - "confirmed, it will be skipped.") - ) - parser.add_option( - "-d", "--dry-run", action="store_true", dest="dry_run", - default=False, help="Print out, but not actually run any steps." - ) - - popts, pargs = parser.parse_args() - CONFIRM_STEPS = popts.confirm_steps - DRY_RUN = popts.dry_run - - print('Releasing Version:', ver) - - if not click.confirm('Are you sure?', default=False): - sys.exit(1) - - commit_for_release('mycli/__init__.py', ver) - create_git_tag('v{}'.format(ver)) - create_distribution_files() - push_to_github() - push_tags_to_github() - upload_distribution_files() diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 603efa20..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,17 +0,0 @@ -pytest>=3.3.0 -pytest-cov>=2.4.0 -tox -twine>=1.12.1 -behave>=1.2.4 -pexpect>=3.3 -coverage>=5.0.4 -autopep8==1.3.3 -colorama>=0.4.1 -git+https://github.com/hayd/pep8radius.git # --error-status option not released -click>=7.0 -paramiko==2.11.0 -pyperclip>=1.8.1 -importlib_resources>=5.0.0 -pyaes>=1.6.1 -sqlglot>=5.1.3 -setuptools<=71.1.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index e533c7b7..00000000 --- a/setup.cfg +++ /dev/null @@ -1,18 +0,0 @@ -[bdist_wheel] -universal = 1 - -[tool:pytest] -addopts = --capture=sys - --showlocals - --doctest-modules - --doctest-ignore-import-errors - --ignore=setup.py - --ignore=mycli/magic.py - --ignore=mycli/packages/parseutils.py - --ignore=test/features - -[pep8] -rev = master -docformatter = True -diff = True -error-status = True diff --git a/setup.py b/setup.py deleted file mode 100755 index c7f93331..00000000 --- a/setup.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python - -import ast -import re -import subprocess -import sys - -from setuptools import Command, find_packages, setup -from setuptools.command.test import test as TestCommand - -_version_re = re.compile(r'__version__\s+=\s+(.*)') - -with open('mycli/__init__.py') as f: - version = ast.literal_eval(_version_re.search( - f.read()).group(1)) - -description = 'CLI for MySQL Database. With auto-completion and syntax highlighting.' - -install_requirements = [ - 'click >= 7.0', - # Pinning cryptography is not needed after paramiko 2.11.0. Correct it - 'cryptography >= 1.0.0', - # 'Pygments>=1.6,<=2.11.1', - 'Pygments>=1.6', - 'prompt_toolkit>=3.0.6,<4.0.0', - 'PyMySQL >= 0.9.2', - 'sqlparse>=0.3.0,<0.5.0', - 'sqlglot>=5.1.3', - 'configobj >= 5.0.5', - 'cli_helpers[styles] >= 2.2.1', - 'pyperclip >= 1.8.1', - 'pyaes >= 1.6.1', - 'pyfzf >= 0.3.1', -] - -if sys.version_info.minor < 9: - install_requirements.append('importlib_resources >= 5.0.0') - - -class lint(Command): - description = 'check code against PEP 8 (and fix violations)' - - user_options = [ - ('branch=', 'b', 'branch/revision to compare against (e.g. main)'), - ('fix', 'f', 'fix the violations in place'), - ('error-status', 'e', 'return an error code on failed PEP check'), - ] - - def initialize_options(self): - """Set the default options.""" - self.branch = 'main' - self.fix = False - self.error_status = True - - def finalize_options(self): - pass - - def run(self): - cmd = 'pep8radius {}'.format(self.branch) - if self.fix: - cmd += ' --in-place' - if self.error_status: - cmd += ' --error-status' - sys.exit(subprocess.call(cmd, shell=True)) - - -class test(TestCommand): - - user_options = [ - ('pytest-args=', 'a', 'Arguments to pass to pytest'), - ('behave-args=', 'b', 'Arguments to pass to pytest') - ] - - def initialize_options(self): - TestCommand.initialize_options(self) - self.pytest_args = '' - self.behave_args = '--no-capture' - - def run_tests(self): - unit_test_errno = subprocess.call( - 'pytest test/ ' + self.pytest_args, - shell=True - ) - cli_errno = subprocess.call( - 'behave test/features ' + self.behave_args, - shell=True - ) - subprocess.run(['git', 'checkout', '--', 'test/myclirc'], check=False) - sys.exit(unit_test_errno or cli_errno) - - -setup( - name='mycli', - author='Mycli Core Team', - author_email='mycli-dev@googlegroups.com', - version=version, - url='http://mycli.net', - packages=find_packages(exclude=['test*']), - package_data={'mycli': ['myclirc', 'AUTHORS', 'SPONSORS']}, - description=description, - long_description=description, - install_requires=install_requirements, - entry_points={ - 'console_scripts': ['mycli = mycli.main:cli'], - }, - cmdclass={'lint': lint, 'test': test}, - python_requires=">=3.7", - classifiers=[ - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: Unix', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: SQL', - 'Topic :: Database', - 'Topic :: Database :: Front-Ends', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries :: Python Modules', - ], - extras_require={ - 'ssh': ['paramiko'], - }, -) diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index 1325596d..00000000 --- a/test/conftest.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest -from .utils import (HOST, USER, PASSWORD, PORT, CHARSET, create_db, - db_connection, SSH_USER, SSH_HOST, SSH_PORT) -import mycli.sqlexecute - - -@pytest.fixture(scope="function") -def connection(): - create_db('mycli_test_db') - connection = db_connection('mycli_test_db') - yield connection - - connection.close() - - -@pytest.fixture -def cursor(connection): - with connection.cursor() as cur: - return cur - - -@pytest.fixture -def executor(connection): - return mycli.sqlexecute.SQLExecute( - database='mycli_test_db', user=USER, - host=HOST, password=PASSWORD, port=PORT, socket=None, charset=CHARSET, - local_infile=False, ssl=None, ssh_user=SSH_USER, ssh_host=SSH_HOST, - ssh_port=SSH_PORT, ssh_password=None, ssh_key_filename=None - ) diff --git a/test/features/basic_commands.feature b/test/features/basic_commands.feature index a12e8992..74a39d9c 100644 --- a/test/features/basic_commands.feature +++ b/test/features/basic_commands.feature @@ -1,5 +1,7 @@ Feature: run the cli, call the help command, + check our application name, + insert the date, exit the cli Scenario: run "\?" command @@ -14,6 +16,10 @@ Feature: run the cli, When we run query to check application_name then we see found + Scenario: insert the date + When we send "ctrl + o, ctrl + d" + then we see the date + Scenario: run the cli and exit When we send "ctrl + d" then dbcli exits diff --git a/test/features/crud_table.feature b/test/features/crud_table.feature index 3384efd7..1e639b04 100644 --- a/test/features/crud_table.feature +++ b/test/features/crud_table.feature @@ -38,6 +38,9 @@ Feature: manipulate tables: and we answer the destructive warning with "n" then we see text "Wise choice!" + # TODO (amjith). This scenario fails in GH actions but only in 3.12. Unable + # to reproduce locally. + @skip_py312 Scenario: no destructive warning if disabled in config When we run dbcli with --no-warn and we query "create table blabla(x integer);" diff --git a/test/features/db_utils.py b/test/features/db_utils.py index be550e9f..ff649dd1 100644 --- a/test/features/db_utils.py +++ b/test/features/db_utils.py @@ -1,8 +1,11 @@ +# type: ignore + import pymysql +from mycli.constants import DEFAULT_CHARSET, DEFAULT_HOST, DEFAULT_PORT + -def create_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def create_db(hostname=DEFAULT_HOST, port=DEFAULT_PORT, username=None, password=None, dbname=None): """Create test database. :param hostname: string @@ -14,17 +17,12 @@ def create_db(hostname='localhost', port=3306, username=None, """ cn = pymysql.connect( - host=hostname, - port=port, - user=username, - password=password, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + host=hostname, port=port, user=username, password=password, charset=DEFAULT_CHARSET, cursorclass=pymysql.cursors.DictCursor ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) - cr.execute('create database ' + dbname) + cr.execute("drop database if exists " + dbname) + cr.execute("create database " + dbname) cn.close() @@ -49,15 +47,14 @@ def create_cn(hostname, port, password, username, dbname): user=username, password=password, db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + charset=DEFAULT_CHARSET, + cursorclass=pymysql.cursors.DictCursor, ) return cn -def drop_db(hostname='localhost', port=3306, username=None, - password=None, dbname=None): +def drop_db(hostname=DEFAULT_HOST, port=DEFAULT_PORT, username=None, password=None, dbname=None): """Drop database. :param hostname: string @@ -73,12 +70,12 @@ def drop_db(hostname='localhost', port=3306, username=None, user=username, password=password, db=dbname, - charset='utf8mb4', - cursorclass=pymysql.cursors.DictCursor + charset=DEFAULT_CHARSET, + cursorclass=pymysql.cursors.DictCursor, ) with cn.cursor() as cr: - cr.execute('drop database if exists ' + dbname) + cr.execute("drop database if exists " + dbname) close_cn(cn) diff --git a/test/features/environment.py b/test/features/environment.py index 1ea0f086..0448fc24 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -1,104 +1,80 @@ +# type: ignore + import os import shutil import sys -from tempfile import mkstemp +from tempfile import NamedTemporaryFile import db_utils as dbutils import fixture_utils as fixutils import pexpect +from mycli.constants import DEFAULT_HOST, DEFAULT_PORT, DEFAULT_USER from steps.wrappers import run_cli, wait_prompt +from test.utils import TEMPFILE_PREFIX -test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log") -SELF_CONNECTING_FEATURES = ( - 'test/features/connection.feature', -) +SELF_CONNECTING_FEATURES = ("test/features/connection.feature",) -MY_CNF_PATH = os.path.expanduser('~/.my.cnf') -MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' -MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') -MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' +MY_CNF_PATH = os.path.expanduser("~/.my.cnf") +MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup" +MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf") +MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup" def get_db_name_from_context(context): - return context.config.userdata.get( - 'my_test_db', None - ) or "mycli_behave_tests" - + return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests" def before_all(context): """Set env parameters.""" - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['EDITOR'] = 'ex' - os.environ['LC_ALL'] = 'en_US.UTF-8' - os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' - os.environ['MYCLI_HISTFILE'] = os.devnull - - test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - login_path_file = os.path.join(test_dir, 'mylogin.cnf') -# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file - - context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - - os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, - '.coveragerc') - + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["VISUAL"] = "ex" + os.environ["EDITOR"] = "ex" + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + os.environ["MYCLI_HISTFILE"] = os.devnull + + # test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + # login_path_file = os.path.join(test_dir, "mylogin.cnf") + # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + + context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) context.exit_sent = False - vi = '_'.join([str(x) for x in sys.version_info[:3]]) + vi = "_".join([str(x) for x in sys.version_info[:3]]) db_name = get_db_name_from_context(context) - db_name_full = '{0}_{1}'.format(db_name, vi) + db_name_full = f"{db_name}_{vi}" # Store get params from config/environment variables context.conf = { - 'host': context.config.userdata.get( - 'my_test_host', - os.getenv('PYTEST_HOST', 'localhost') - ), - 'port': context.config.userdata.get( - 'my_test_port', - int(os.getenv('PYTEST_PORT', '3306')) - ), - 'user': context.config.userdata.get( - 'my_test_user', - os.getenv('PYTEST_USER', 'root') - ), - 'pass': context.config.userdata.get( - 'my_test_pass', - os.getenv('PYTEST_PASSWORD', None) - ), - 'cli_command': context.config.userdata.get( - 'my_cli_command', None) or - sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', - 'dbname': db_name, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - 'pager_boundary': '---boundary---', + "host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", DEFAULT_HOST)), + "port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", DEFAULT_PORT))), + "user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", DEFAULT_USER)), + "pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)), + "cli_command": context.config.userdata.get("my_cli_command", None) + or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.click_entrypoint()"', + "dbname": db_name, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", } - _, my_cnf = mkstemp() - with open(my_cnf, 'w') as f: - f.write( - '[client]\n' - 'pager={0} {1} {2}\n'.format( - sys.executable, os.path.join(context.package_root, - 'test/features/wrappager.py'), - context.conf['pager_boundary']) + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as my_cnf: + my_cnf.write( + f'[client]\npager={sys.executable} ' + f'{os.path.join(context.package_root, "test/features/wrappager.py")} {context.conf["pager_boundary"]}\n' ) - context.conf['defaults-file'] = my_cnf - context.conf['myclirc'] = os.path.join(context.package_root, 'test', - 'myclirc') + context.conf["defaults-file"] = my_cnf.name + context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc") - context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], - context.conf['user'], - context.conf['pass'], - context.conf['dbname']) + context.cn = dbutils.create_db( + context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"] + ) context.fixture_data = fixutils.read_fixture_files() @@ -106,12 +82,15 @@ def before_all(context): def after_all(context): """Unset env parameters.""" dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['port'], - context.conf['user'], context.conf['pass'], - context.conf['dbname']) + dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]) + try: + if os.path.exists(context.conf["defaults-file"]): + os.remove(context.conf["defaults-file"]) + except Exception: + pass # Restore env vars. - #for k, v in context.pgenv.items(): + # for k, v in context.pgenv.items(): # if k in os.environ and v is None: # del os.environ[k] # elif v: @@ -123,8 +102,11 @@ def before_step(context, _): def before_scenario(context, arg): - with open(test_log_file, 'w') as f: - f.write('') + # Skip scenarios marked skip_py312 when running on Python 3.12 + if sys.version_info[:2] == (3, 12) and "skip_py312" in arg.tags: + arg.skip("Skipped on Python 3.12") + with open(test_log_file, "w") as f: + f.write("") if arg.location.filename not in SELF_CONNECTING_FEATURES: run_cli(context) wait_prompt(context) @@ -140,23 +122,18 @@ def after_scenario(context, _): """Cleans up after each test complete.""" with open(test_log_file) as f: for line in f: - if 'error' in line.lower(): - raise RuntimeError(f'Error in log file: {line}') + if "error" in line.lower(): + raise RuntimeError(f"Error in log file: {line}") - if hasattr(context, 'cli') and not context.exit_sent: + if hasattr(context, "cli") and not context.exit_sent: # Quit nicely. if not context.atprompt: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - context.cli.expect_exact( - '{0}@{1}:{2}>'.format( - user, host, dbname - ), - timeout=5 - ) - context.cli.sendcontrol('c') - context.cli.sendcontrol('d') + context.cli.expect_exact(f"{user}@{host}:{dbname}>", timeout=5) + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") context.cli.expect_exact(pexpect.EOF, timeout=5) if os.path.exists(MY_CNF_BACKUP_PATH): diff --git a/test/features/fixture_data/help.txt b/test/features/fixture_data/help.txt deleted file mode 100644 index deb499a4..00000000 --- a/test/features/fixture_data/help.txt +++ /dev/null @@ -1,24 +0,0 @@ -+--------------------------+-----------------------------------------------+ -| Command | Description | -|--------------------------+-----------------------------------------------| -| \# | Refresh auto-completions. | -| \? | Show Help. | -| \c[onnect] database_name | Change to a new database. | -| \d [pattern] | List or describe tables, views and sequences. | -| \dT[S+] [pattern] | List data types | -| \df[+] [pattern] | List functions. | -| \di[+] [pattern] | List indexes. | -| \dn[+] [pattern] | List schemas. | -| \ds[+] [pattern] | List sequences. | -| \dt[+] [pattern] | List tables. | -| \du[+] [pattern] | List roles. | -| \dv[+] [pattern] | List views. | -| \e [file] | Edit the query with external editor. | -| \l | List databases. | -| \n[+] [name] | List or execute named queries. | -| \nd [name [query]] | Delete a named query. | -| \ns name query | Save a named query. | -| \refresh | Refresh auto-completions. | -| \timing | Toggle timing of commands. | -| \x | Toggle expanded output. | -+--------------------------+-----------------------------------------------+ diff --git a/test/features/fixture_data/help_commands.txt b/test/features/fixture_data/help_commands.txt index 2c06d5d2..26a23914 100644 --- a/test/features/fixture_data/help_commands.txt +++ b/test/features/fixture_data/help_commands.txt @@ -1,31 +1,38 @@ -+-------------+----------------------------+------------------------------------------------------------+ -| Command | Shortcut | Description | -+-------------+----------------------------+------------------------------------------------------------+ -| \G | \G | Display current query results vertically. | -| \clip | \clip | Copy query to the system clipboard. | -| \dt | \dt[+] [table] | List or describe tables. | -| \e | \e | Edit command with editor (uses $EDITOR). | -| \f | \f [name [args..]] | List or execute favorite queries. | -| \fd | \fd [name] | Delete a favorite query. | -| \fs | \fs name query | Save a favorite query. | -| \l | \l | List databases. | -| \once | \o [-o] filename | Append next result to an output file (overwrite using -o). | -| \pipe_once | \| command | Send next result to a subprocess. | -| \timing | \t | Toggle timing of commands. | -| connect | \r | Reconnect to the database. Optional database argument. | -| exit | \q | Exit. | -| help | \? | Show this help. | -| nopager | \n | Disable pager, print to stdout. | -| notee | notee | Stop writing results to an output file. | -| pager | \P [command] | Set PAGER. Print the query results via PAGER. | -| prompt | \R | Change prompt format. | -| quit | \q | Quit. | -| rehash | \# | Refresh auto-completions. | -| source | \. filename | Execute commands from file. | -| status | \s | Get status information from the server. | -| system | system [command] | Execute a system shell commmand. | -| tableformat | \T | Change the table format used to output results. | -| tee | tee [-o] filename | Append all results to an output file (overwrite using -o). | -| use | \u | Change to a new database. | -| watch | watch [seconds] [-c] query | Executes the query every [seconds] seconds (by default 5). | -+-------------+----------------------------+------------------------------------------------------------+ ++----------------+----------+---------------------------------+-------------------------------------------------------------+ +| Command | Shortcut | Usage | Description | ++----------------+----------+---------------------------------+-------------------------------------------------------------+ +| \bug | | \bug | File a bug on GitHub. | +| \clip | | \clip | Copy query to the system clipboard. | +| \dt | | \dt[+] [table] | List or describe tables. | +| \edit | \e | \edit | \edit | Edit query with editor (uses $VISUAL or $EDITOR). | +| \f | | \f [name [args..]] | List or execute favorite queries. | +| \fd | | \fd | Delete a favorite query. | +| \fs | | \fs | Save a favorite query. | +| \g | | \g | Display query results (mnemonic: go). | +| \G | | \G | Display query results vertically. | +| \l | | \l | List databases. | +| \llm | \ai | \llm [arguments] | Interrogate an LLM. See "\llm help". | +| \once | \o | \once [-o] | Append next result to an output file (overwrite using -o). | +| \pipe_once | \| | \pipe_once | Send next result to a subprocess. | +| \timing | \t | \timing | Toggle timing of queries. | +| connect | \r | connect [database] | Reconnect to the server, optionally switching databases. | +| delimiter | | delimiter | Change end-of-statement delimiter. | +| exit | \q | exit | Exit. | +| help | \? | help [term] | Show this table, or search for help on a term. | +| nopager | \n | nopager | Disable pager; print to stdout. | +| notee | | notee | Stop writing results to an output file. | +| nowarnings | \w | nowarnings | Disable automatic warnings display. | +| pager | \P | pager [command] | Set pager to [command]. Print query results via pager. | +| prompt | \R | prompt | Change prompt format. | +| quit | \q | quit | Quit. | +| redirectformat | \Tr | redirectformat | Change the table format used to output redirected results. | +| rehash | \# | rehash | Refresh auto-completions. | +| source | \. | source | Execute queries from a file. | +| status | \s | status | Get status information from the server. | +| system | | system [-r] | Execute a system shell command (raw mode with -r). | +| tableformat | \T | tableformat | Change the table format used to output interactive results. | +| tee | | tee [-o] | Append all results to an output file (overwrite using -o). | +| use | \u | use | Change to a new database. | +| warnings | \W | warnings | Enable automatic warnings display. | +| watch | | watch [seconds] [-c] | Execute query every [seconds] seconds (5 by default). | ++----------------+----------+---------------------------------+-------------------------------------------------------------+ diff --git a/test/features/fixture_utils.py b/test/features/fixture_utils.py index f85e0f65..0e624c2f 100644 --- a/test/features/fixture_utils.py +++ b/test/features/fixture_utils.py @@ -1,5 +1,6 @@ +# type: ignore + import os -import io def read_fixture_lines(filename): @@ -20,9 +21,9 @@ def read_fixture_files(): fixture_dict = {} current_dir = os.path.dirname(__file__) - fixture_dir = os.path.join(current_dir, 'fixture_data/') + fixture_dir = os.path.join(current_dir, "fixture_data/") for filename in os.listdir(fixture_dir): - if filename not in ['.', '..']: + if filename not in [".", ".."]: fullname = os.path.join(fixture_dir, filename) fixture_dict[filename] = read_fixture_lines(fullname) diff --git a/test/features/iocommands.feature b/test/features/iocommands.feature index 95366eba..3a523c39 100644 --- a/test/features/iocommands.feature +++ b/test/features/iocommands.feature @@ -27,15 +27,15 @@ Feature: I/O commands Scenario: set delimiter and query on same line When we query "select 123; delimiter $ select 456 $ delimiter %" - then we see result "123" - and we see result "456" + then we see tabular result "123" + and we see tabular result "456" and delimiter is set to "%" Scenario: send output to file When we query "\o /tmp/output1.sql" and we query "select 123" and we query "system cat /tmp/output1.sql" - then we see result "123" + then we see csv result "123" Scenario: send output to file two times When we query "\o /tmp/output1.sql" @@ -43,5 +43,56 @@ Feature: I/O commands and we query "\o /tmp/output2.sql" and we query "select 456" and we query "system cat /tmp/output2.sql" - then we see result "456" - \ No newline at end of file + then we see csv result "456" + + Scenario: shell style redirect to file + When we query "select 123 as constant $> /tmp/output1.csv" + and we query "system cat /tmp/output1.csv" + then we see csv 123 in file output + + Scenario: shell style redirect to command + When we query "select 100 $| wc" + then we see space 12 in command output + + Scenario: shell style redirect to multiple commands + When we query "select 100 $| head -1 $| wc" + then we see space 6 in command output + + Scenario: shell style redirect to multiple commands with minimal spaces + When we query "select 100$|head -1$|wc" + then we see space 6 in command output + + Scenario: shell style redirect to multiple commands containing single quotes + When we query "select 100 $| head '-1' $| wc" + then we see space 6 in command output + + Scenario: shell style redirect to multiple commands containing single quotes and minimal spaces + When we query "select 100$|head '-1'$|wc" + then we see space 6 in command output + + Scenario: shell style redirect to multiple commands containing mixed quoted and unquoted arg + When we query "select 100 $| head -'1' $| wc" + then we see space 6 in command output + + Scenario: shell style redirect to multiple commands containing double quotes + When we query "select 100 $| head ""-1"" $| wc" + then we see space 6 in command output + + Scenario: shell style redirect with commands and capture to file + When we query "select 100 $| head -1 $| wc $> /tmp/output1.txt" + and we query "system cat /tmp/output1.txt" + then we see text 6 in file output + + Scenario: shell style redirect with append to file + When we query "select 100 $> /tmp/output1.csv" + and we query "select 200 $>> /tmp/output1.csv" + and we query "system cat /tmp/output1.csv" + then we see csv 100 in file output + and we see csv 200 in file output + + Scenario: shell style redirect with command and append to file + When we query "select 300 $| grep 0 $> /tmp/output1.csv" + and we query "select 400 $| grep 0 $>> /tmp/output1.csv" + and we query "system cat /tmp/output1.csv" + then we see csv 300 in file output + and we see csv 400 in file output diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index e1cb26f8..33b43375 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -1,46 +1,55 @@ +# type: ignore + from textwrap import dedent from behave import then, when - -import wrappers from utils import parse_cli_args_to_dict +import wrappers -@when('we run dbcli with {arg}') +@when("we run dbcli with {arg}") def step_run_cli_with_arg(context, arg): wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) -@when('we execute a small query') +@when("we execute a small query") def step_execute_small_query(context): - context.cli.sendline('select 1') + context.cli.sendline("select 1") -@when('we execute a large query') +@when("we execute a large query") def step_execute_large_query(context): - context.cli.sendline( - 'select {}'.format(','.join([str(n) for n in range(1, 50)]))) + context.cli.sendline(f"select {','.join([str(n) for n in range(1, 50)])}") -@then('we see small results in horizontal format') +@then("we see small results in horizontal format") def step_see_small_results(context): - wrappers.expect_pager(context, dedent("""\ - +---+\r - | 1 |\r - +---+\r - | 1 |\r - +---+\r - \r - """), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) - - -@then('we see large results in vertical format') + expected = ( + dedent( + """ + +---+\r + | 1 |\r + +---+\r + | 1 |\r + +---+ + """ + ).strip() + + '\r\n\r\n' + ) + + wrappers.expect_pager( + context, + expected, + timeout=5, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) + + +@then("we see large results in vertical format") def step_see_large_results(context): - rows = ['{n:3}| {n}'.format(n=str(n)) for n in range(1, 50)] - expected = ('***************************[ 1. row ]' - '***************************\r\n' + - '{}\r\n'.format('\r\n'.join(rows) + '\r\n')) + rows = [f"{str(n):3}| {n}" for n in range(1, 50)] + delimited_rows = '\r\n'.join(rows) + '\r\n' + expected = "***************************[ 1. row ]***************************\r\n" + delimited_rows + "\r\n" wrappers.expect_pager(context, expected, timeout=10) - wrappers.expect_exact(context, '1 row in set', timeout=2) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/basic_commands.py b/test/features/steps/basic_commands.py index 425ef674..f94d4937 100644 --- a/test/features/steps/basic_commands.py +++ b/test/features/steps/basic_commands.py @@ -1,3 +1,5 @@ +# type: ignore + """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used @@ -5,18 +7,22 @@ """ -from behave import when -from textwrap import dedent +import datetime import tempfile +from textwrap import dedent + +from behave import then, when import wrappers +from test.utils import TEMPFILE_PREFIX -@when('we run dbcli') + +@when("we run dbcli") def step_run_cli(context): wrappers.run_cli(context) -@when('we wait for prompt') +@when("we wait for prompt") def step_wait_prompt(context): wrappers.wait_prompt(context) @@ -24,77 +30,114 @@ def step_wait_prompt(context): @when('we send "ctrl + d"') def step_ctrl_d(context): """Send Ctrl + D to hopefully exit.""" - context.cli.sendcontrol('d') + context.cli.sendcontrol("d") context.exit_sent = True -@when('we send "\?" command') +@when('we send "ctrl + o, ctrl + d"') +def step_ctrl_o_ctrl_d(context): + """Send ctrl + o, ctrl + d to insert the quoted date.""" + context.cli.send("SELECT ") + context.cli.sendcontrol("o") + context.cli.sendcontrol("d") + context.cli.send(" AS dt") + context.cli.sendline("") + + +@when(r'we send "\?" command') def step_send_help(context): - """Send \? + r"""Send \? to see help. """ - context.cli.sendline('\\?') - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline("\\?") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we send source command') +@when("we send source command") def step_send_source_command(context): - with tempfile.NamedTemporaryFile() as f: - f.write(b'\?') + with tempfile.NamedTemporaryFile(prefix=TEMPFILE_PREFIX) as f: + f.write(b"\\?") f.flush() - context.cli.sendline('\. {0}'.format(f.name)) - wrappers.expect_exact( - context, context.conf['pager_boundary'] + '\r\n', timeout=5) + context.cli.sendline(f"\\. {f.name}") + wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5) -@when(u'we run query to check application_name') +@when("we run query to check application_name") def step_check_application_name(context): context.cli.sendline( - "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli'" + "SELECT 'found' FROM performance_schema.session_connect_attrs WHERE attr_name = 'program_name' AND attr_value = 'mycli' LIMIT 1" ) -@then(u'we see found') +@then("we see found") def step_see_found(context): - wrappers.expect_exact( - context, - context.conf['pager_boundary'] + '\r' + dedent(''' + expected = ( + dedent( + """ +-------+\r | found |\r +-------+\r | found |\r - +-------+\r - \r - ''') + context.conf['pager_boundary'], - timeout=5 + +-------+ + """ + ).strip() + + '\r\n\r\n' ) + wrappers.expect_exact( + context, + context.conf["pager_boundary"] + '\r\n' + expected + context.conf["pager_boundary"], + timeout=5, + ) + + +@then("we see the date") +def step_see_date(context): + # There are some edge cases in which this test could fail, + # such as running near midnight when the test database has + # a different TZ setting than the system. + date_str = datetime.datetime.now().strftime("%Y-%m-%d") + expected = ( + dedent( + f""" + +------------+\r + | dt |\r + +------------+\r + | {date_str} |\r + +------------+ + """ + ).strip() + + '\r\n\r\n' + ) -@then(u'we confirm the destructive warning') -def step_confirm_destructive_command(context): - """Confirm destructive command.""" wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) - context.cli.sendline('y') + context, + context.conf["pager_boundary"] + '\r\n' + expected + context.conf["pager_boundary"], + timeout=5, + ) -@when(u'we answer the destructive warning with "{confirmation}"') -def step_confirm_destructive_command(context, confirmation): +@then("we confirm the destructive warning") +def step_confirm_destructive_command(context): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) + context.cli.sendline("y") + + +@when('we answer the destructive warning with "{confirmation}"') +def step_confirm_destructive_command(context, confirmation): # noqa + """Confirm destructive command.""" + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) -@then(u'we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') -def step_confirm_destructive_command(context, confirmation, text): +@then('we answer the destructive warning with invalid "{confirmation}" and see text "{text}"') +def step_confirm_destructive_command(context, confirmation, text): # noqa """Confirm destructive command.""" - wrappers.expect_exact( - context, 'You\'re about to run a destructive command.\r\nDo you want to proceed? (y/n):', timeout=2) + wrappers.expect_exact(context, "You're about to run a destructive command.\r\nDo you want to proceed? (y/n):", timeout=2) context.cli.sendline(confirmation) wrappers.expect_exact(context, text, timeout=2) # we must exit the Click loop, or the feature will hang - context.cli.sendline('n') + context.cli.sendline("n") diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py index e16dd867..dbc1eb4d 100644 --- a/test/features/steps/connection.py +++ b/test/features/steps/connection.py @@ -1,71 +1,54 @@ +# type: ignore + import io import os -import shlex - -from behave import when, then -import pexpect +from behave import then, when import wrappers -from test.features.steps.utils import parse_cli_args_to_dict -from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context -from test.utils import HOST, PORT, USER, PASSWORD -from mycli.config import encrypt_mylogin_cnf +from mycli.config import encrypt_mylogin_cnf +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.features.steps.utils import parse_cli_args_to_dict +from test.utils import HOST, PASSWORD, PORT, USER -TEST_LOGIN_PATH = 'test_login_path' +TEST_LOGIN_PATH = "test_login_path" @when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') @when('we run mycli without arguments "{excluded_args}"') -def step_run_cli_without_args(context, excluded_args, exact_args=''): - wrappers.run_cli( - context, - run_args=parse_cli_args_to_dict(exact_args), - exclude_args=parse_cli_args_to_dict(excluded_args).keys() - ) +def step_run_cli_without_args(context, excluded_args, exact_args=""): + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(exact_args), exclude_args=parse_cli_args_to_dict(excluded_args).keys()) @then('status contains "{expression}"') def status_contains(context, expression): - wrappers.expect_exact(context, f'{expression}', timeout=5) + wrappers.expect_exact(context, f"{expression}", timeout=5) # Normally, the shutdown after scenario waits for the prompt. # But we may have changed the prompt, depending on parameters, # so let's wait for its last character - context.cli.expect_exact('>') + context.cli.expect_exact(">") context.atprompt = True -@when('we create my.cnf file') +@when("we create my.cnf file") def step_create_my_cnf_file(context): - my_cnf = ( - '[client]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MY_CNF_PATH, 'w') as f: + my_cnf = f"[client]\nhost = {HOST}\nport = {PORT}\nuser = {USER}\npassword = {PASSWORD}\n" + with open(MY_CNF_PATH, "w") as f: f.write(my_cnf) -@when('we create mylogin.cnf file') +@when("we create mylogin.cnf file") def step_create_mylogin_cnf_file(context): - os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) - mylogin_cnf = ( - f'[{TEST_LOGIN_PATH}]\n' - f'host = {HOST}\n' - f'port = {PORT}\n' - f'user = {USER}\n' - f'password = {PASSWORD}\n' - ) - with open(MYLOGIN_CNF_PATH, 'wb') as f: + os.environ.pop("MYSQL_TEST_LOGIN_FILE", None) + mylogin_cnf = f"[{TEST_LOGIN_PATH}]\nhost = {HOST}\nport = {PORT}\nuser = {USER}\npassword = {PASSWORD}\n" + with open(MYLOGIN_CNF_PATH, "wb") as f: input_file = io.StringIO(mylogin_cnf) f.write(encrypt_mylogin_cnf(input_file).read()) -@then('we are logged in') +@then("we are logged in") def we_are_logged_in(context): db_name = get_db_name_from_context(context) - context.cli.expect_exact(f'{db_name}>', timeout=5) + context.cli.expect_exact(f"{db_name}>", timeout=5) context.atprompt = True diff --git a/test/features/steps/crud_database.py b/test/features/steps/crud_database.py index 841f37d0..3356a112 100644 --- a/test/features/steps/crud_database.py +++ b/test/features/steps/crud_database.py @@ -1,3 +1,5 @@ +# type: ignore + """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used @@ -5,111 +7,109 @@ """ +from behave import then, when import pexpect - import wrappers -from behave import when, then + +from mycli.constants import DEFAULT_DATABASE -@when('we create database') +@when("we create database") def step_db_create(context): """Send create database.""" - context.cli.sendline('create database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline(f"create database {context.conf['dbname_tmp']};") - context.response = { - 'database_name': context.conf['dbname_tmp'] - } + context.response = {"database_name": context.conf["dbname_tmp"]} -@when('we drop database') +@when("we drop database") def step_db_drop(context): """Send drop database.""" - context.cli.sendline('drop database {0};'.format( - context.conf['dbname_tmp'])) + context.cli.sendline(f"drop database {context.conf['dbname_tmp']};") -@when('we connect to test database') +@when("we connect to test database") def step_db_connect_test(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use {0};'.format(db_name)) + context.cli.sendline(f"use {db_name};") -@when('we connect to quoted test database') +@when("we connect to quoted test database") def step_db_connect_quoted_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname'] + db_name = context.conf["dbname"] context.currentdb = db_name - context.cli.sendline('use `{0}`;'.format(db_name)) + context.cli.sendline(f"use `{db_name}`;") -@when('we connect to tmp database') +@when("we connect to tmp database") def step_db_connect_tmp(context): """Send connect to database.""" - db_name = context.conf['dbname_tmp'] + db_name = context.conf["dbname_tmp"] context.currentdb = db_name - context.cli.sendline('use {0}'.format(db_name)) + context.cli.sendline(f"use {db_name}") -@when('we connect to dbserver') +@when("we connect to dbserver") def step_db_connect_dbserver(context): """Send connect to database.""" - context.currentdb = 'mysql' - context.cli.sendline('use mysql') + context.currentdb = DEFAULT_DATABASE + context.cli.sendline(f"use {DEFAULT_DATABASE}") -@then('dbcli exits') +@then("dbcli exits") def step_wait_exit(context): """Make sure the cli exits.""" wrappers.expect_exact(context, pexpect.EOF, timeout=5) -@then('we see dbcli prompt') +@then("we see dbcli prompt") def step_see_prompt(context): """Wait to see the prompt.""" - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - wrappers.wait_prompt(context, '{0}@{1}:{2}> '.format(user, host, dbname)) + wrappers.wait_prompt(context, f"{user}@{host}:{dbname}> ") -@then('we see help output') +@then("we see help output") def step_see_help(context): - for expected_line in context.fixture_data['help_commands.txt']: + for expected_line in context.fixture_data["help_commands.txt"]: + # in case tests are run without extras + if 'LLM' in expected_line: + continue wrappers.expect_exact(context, expected_line, timeout=1) -@then('we see database created') +@then("we see database created") def step_see_db_created(context): """Wait to see create database output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see database dropped') +@then("we see database dropped") def step_see_db_dropped(context): """Wait to see drop database output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see database dropped and no default database') +@then("we see database dropped and no default database") def step_see_db_dropped_no_default(context): """Wait to see drop database output.""" - user = context.conf['user'] - host = context.conf['host'] - database = '(none)' + user = context.conf["user"] + host = context.conf["host"] + database = "(none)" context.currentdb = None - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) - wrappers.wait_prompt(context, '{0}@{1}:{2}>'.format(user, host, database)) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) + wrappers.wait_prompt(context, f"{user}@{host}:{database}>") -@then('we see database connected') +@then("we see database connected") def step_see_db_connected(context): """Wait to see drop database output.""" - wrappers.expect_exact( - context, 'You are now connected to database "', timeout=2) + wrappers.expect_exact(context, 'connected to database "', timeout=2) wrappers.expect_exact(context, '"', timeout=2) - wrappers.expect_exact(context, ' as user "{0}"'.format( - context.conf['user']), timeout=2) + wrappers.expect_exact(context, f' as user "{context.conf["user"]}"', timeout=2) diff --git a/test/features/steps/crud_table.py b/test/features/steps/crud_table.py index f715f0ca..d76c6964 100644 --- a/test/features/steps/crud_table.py +++ b/test/features/steps/crud_table.py @@ -1,3 +1,5 @@ +# type: ignore + """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used @@ -5,108 +7,127 @@ """ -import wrappers -from behave import when, then from textwrap import dedent +from behave import then, when +import wrappers + -@when('we create table') +@when("we create table") def step_create_table(context): """Send create table.""" - context.cli.sendline('create table a(x text);') + context.cli.sendline("create table a(x text);") -@when('we insert into table') +@when("we insert into table") def step_insert_into_table(context): """Send insert into table.""" - context.cli.sendline('''insert into a(x) values('xxx');''') + context.cli.sendline("""insert into a(x) values('xxx');""") -@when('we update table') +@when("we update table") def step_update_table(context): """Send insert into table.""" - context.cli.sendline('''update a set x = 'yyy' where x = 'xxx';''') + context.cli.sendline("""update a set x = 'yyy' where x = 'xxx';""") -@when('we select from table') +@when("we select from table") def step_select_from_table(context): """Send select from table.""" - context.cli.sendline('select * from a;') + context.cli.sendline("select * from a;") -@when('we delete from table') +@when("we delete from table") def step_delete_from_table(context): """Send deete from table.""" - context.cli.sendline('''delete from a where x = 'yyy';''') + context.cli.sendline("""delete from a where x = 'yyy';""") -@when('we drop table') +@when("we drop table") def step_drop_table(context): """Send drop table.""" - context.cli.sendline('drop table a;') + context.cli.sendline("drop table a;") -@then('we see table created') +@then("we see table created") def step_see_table_created(context): """Wait to see create table output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@then('we see record inserted') +@then("we see record inserted") def step_see_record_inserted(context): """Wait to see insert output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see record updated') +@then("we see record updated") def step_see_record_updated(context): """Wait to see update output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see data selected') +@then("we see data selected") def step_see_data_selected(context): """Wait to see select output.""" - wrappers.expect_pager( - context, dedent("""\ + expected = ( + dedent( + """ +-----+\r | x |\r +-----+\r | yyy |\r - +-----+\r - \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + +-----+ + """ + ).strip() + + '\r\n\r\n' + ) + + wrappers.expect_pager( + context, + expected, + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) -@then('we see record deleted') +@then("we see record deleted") def step_see_data_deleted(context): """Wait to see delete output.""" - wrappers.expect_exact(context, 'Query OK, 1 row affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 1 row affected", timeout=2) -@then('we see table dropped') +@then("we see table dropped") def step_see_table_dropped(context): """Wait to see drop output.""" - wrappers.expect_exact(context, 'Query OK, 0 rows affected', timeout=2) + wrappers.expect_exact(context, "Query OK, 0 rows affected", timeout=2) -@when('we select null') +@when("we select null") def step_select_null(context): """Send select null.""" - context.cli.sendline('select null;') + context.cli.sendline("select null;") -@then('we see null selected') +@then("we see null selected") def step_see_null_selected(context): """Wait to see null output.""" - wrappers.expect_pager( - context, dedent("""\ + expected = ( + dedent( + """ +--------+\r | NULL |\r +--------+\r | |\r - +--------+\r - \r - """), timeout=2) - wrappers.expect_exact(context, '1 row in set', timeout=2) + +--------+ + """ + ).strip() + + '\r\n\r\n' + ) + + wrappers.expect_pager( + context, + expected, + timeout=2, + ) + wrappers.expect_exact(context, "1 row in set", timeout=2) diff --git a/test/features/steps/iocommands.py b/test/features/steps/iocommands.py index bbabf431..0792e95f 100644 --- a/test/features/steps/iocommands.py +++ b/test/features/steps/iocommands.py @@ -1,105 +1,134 @@ -import os -import wrappers +# type: ignore -from behave import when, then +import os from textwrap import dedent +from behave import then, when +import wrappers + -@when('we start external editor providing a file name') +@when("we start external editor providing a file name") def step_edit_file(context): """Edit file with external editor.""" - context.editor_file_name = os.path.join( - context.package_root, 'test_file_{0}.sql'.format(context.conf['vi'])) + context.editor_file_name = os.path.join(context.package_root, f"test_file_{context.conf['vi']}.sql") if os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) - context.cli.sendline('\e {0}'.format( - os.path.basename(context.editor_file_name))) - wrappers.expect_exact( - context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2) - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline(f"\\e {os.path.basename(context.editor_file_name)}") + wrappers.expect_exact(context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=4) + wrappers.expect_exact(context, "\r\n:", timeout=4) @when('we type "{query}" in the editor') def step_edit_type_sql(context, query): - context.cli.sendline('i') + context.cli.sendline("i") context.cli.sendline(query) - context.cli.sendline('.') - wrappers.expect_exact(context, '\r\n:', timeout=2) + context.cli.sendline(".") + wrappers.expect_exact(context, "\r\n:", timeout=4) -@when('we exit the editor') +@when("we exit the editor") def step_edit_quit(context): - context.cli.sendline('x') - wrappers.expect_exact(context, "written", timeout=2) + context.cli.sendline("x") + wrappers.expect_exact(context, "written", timeout=4) @then('we see "{query}" in prompt') def step_edit_done_sql(context, query): - for match in query.split(' '): + for match in query.split(" "): wrappers.expect_exact(context, match, timeout=5) # Cleanup the command line. - context.cli.sendcontrol('c') + context.cli.sendcontrol("c") # Cleanup the edited file. if context.editor_file_name and os.path.exists(context.editor_file_name): os.remove(context.editor_file_name) -@when(u'we tee output') +@when("we tee output") def step_tee_ouptut(context): - context.tee_file_name = os.path.join( - context.package_root, 'tee_file_{0}.sql'.format(context.conf['vi'])) + context.tee_file_name = os.path.join(context.package_root, f"tee_file_{context.conf['vi']}.sql") if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) - context.cli.sendline('tee {0}'.format( - os.path.basename(context.tee_file_name))) + context.cli.sendline(f"tee {os.path.basename(context.tee_file_name)}") -@when(u'we select "select {param}"') +@when('we select "select {param}"') def step_query_select_number(context, param): - context.cli.sendline(u'select {}'.format(param)) - wrappers.expect_pager(context, dedent(u"""\ - +{dashes}+\r - | {param} |\r - +{dashes}+\r - | {param} |\r - +{dashes}+\r - \r - """.format(param=param, dashes='-' * (len(param) + 2)) - ), timeout=5) - wrappers.expect_exact(context, '1 row in set', timeout=2) - - -@then(u'we see result "{result}"') -def step_see_result(context, result): - wrappers.expect_exact( + context.cli.sendline(f"select {param}") + expected = ( + dedent( + f""" + +{'-' * (len(param) + 2)}+\r + | {param} |\r + +{'-' * (len(param) + 2)}+\r + | {param} |\r + +{'-' * (len(param) + 2)}+ + """ + ).strip() + + '\r\n\r\n' + ) + + wrappers.expect_pager( context, - u"| {} |".format(result), - timeout=2 + expected, + timeout=5, ) + wrappers.expect_exact(context, "1 row in set", timeout=2) + + +@then('we see tabular result "{result}"') +def step_see_tabular_result(context, result): + wrappers.expect_exact(context, f'| {result} |', timeout=2) + + +@then('we see csv result "{result}"') +def step_see_csv_result(context, result): + wrappers.expect_exact(context, f'"{result}"', timeout=2) -@when(u'we query "{query}"') +@when('we query "{query}"') def step_query(context, query): context.cli.sendline(query) -@when(u'we notee output') +@when("we notee output") def step_notee_output(context): - context.cli.sendline('notee') + context.cli.sendline("notee") -@then(u'we see 123456 in tee output') +@then("we see 123456 in tee output") def step_see_123456_in_ouput(context): with open(context.tee_file_name) as f: - assert '123456' in f.read() + assert "123456" in f.read() if os.path.exists(context.tee_file_name): os.remove(context.tee_file_name) -@then(u'delimiter is set to "{delimiter}"') +@then('we see csv {result} in file output') +def step_see_csv_result_in_redirected_ouput(context, result): + wrappers.expect_exact(context, f'"{result}"', timeout=2) + temp_filename = "/tmp/output1.csv" + if os.path.exists(temp_filename): + os.remove(temp_filename) + + +@then('we see text {result} in file output') +def step_see_text_result_in_redirected_ouput(context, result): + wrappers.expect_exact(context, f' {result}', timeout=2) + temp_filename = "/tmp/output1.txt" + if os.path.exists(temp_filename): + os.remove(temp_filename) + + +@then("we see space 12 in command output") +def step_see_space_12_in_command_ouput(context): + wrappers.expect_exact(context, ' 12', timeout=2) + + +@then("we see space 6 in command output") +def step_see_space_6_in_command_ouput(context): + wrappers.expect_exact(context, ' 6', timeout=2) + + +@then('delimiter is set to "{delimiter}"') def delimiter_is_set(context, delimiter): - wrappers.expect_exact( - context, - u'Changed delimiter to {}'.format(delimiter), - timeout=2 - ) + wrappers.expect_exact(context, f"Changed delimiter to {delimiter}", timeout=2) diff --git a/test/features/steps/named_queries.py b/test/features/steps/named_queries.py index bc1f8663..ea53234c 100644 --- a/test/features/steps/named_queries.py +++ b/test/features/steps/named_queries.py @@ -1,3 +1,5 @@ +# type: ignore + """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used @@ -5,86 +7,83 @@ """ +from behave import then, when import wrappers -from behave import when, then -@when('we save a named query') +@when("we save a named query") def step_save_named_query(context): """Send \fs command.""" - context.cli.sendline('\\fs foo SELECT 12345') + context.cli.sendline("\\fs foo SELECT 12345") -@when('we use a named query') +@when("we use a named query") def step_use_named_query(context): """Send \f command.""" - context.cli.sendline('\\f foo') + context.cli.sendline("\\f foo") -@when('we delete a named query') +@when("we delete a named query") def step_delete_named_query(context): """Send \fd command.""" - context.cli.sendline('\\fd foo') + context.cli.sendline("\\fd foo") -@then('we see the named query saved') +@then("we see the named query saved") def step_see_named_query_saved(context): """Wait to see query saved.""" - wrappers.expect_exact(context, 'Saved.', timeout=2) + wrappers.expect_exact(context, "Saved.", timeout=2) -@then('we see the named query executed') +@then("we see the named query executed") def step_see_named_query_executed(context): """Wait to see select output.""" - wrappers.expect_exact(context, 'SELECT 12345', timeout=2) + wrappers.expect_exact(context, "SELECT 12345", timeout=2) -@then('we see the named query deleted') +@then("we see the named query deleted") def step_see_named_query_deleted(context): """Wait to see query deleted.""" - wrappers.expect_exact(context, 'foo: Deleted', timeout=2) + wrappers.expect_exact(context, "foo: Deleted", timeout=2) -@when('we save a named query with parameters') +@when("we save a named query with parameters") def step_save_named_query_with_parameters(context): """Send \fs command for query with parameters.""" context.cli.sendline('\\fs foo_args SELECT $1, "$2", "$3"') -@when('we use named query with parameters') +@when("we use named query with parameters") def step_use_named_query_with_parameters(context): """Send \f command with parameters.""" context.cli.sendline('\\f foo_args 101 second "third value"') -@then('we see the named query with parameters executed') +@then("we see the named query with parameters executed") def step_see_named_query_with_parameters_executed(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'SELECT 101, "second", "third value"', timeout=2) + wrappers.expect_exact(context, 'SELECT 101, "second", "third value"', timeout=2) -@when('we use named query with too few parameters') +@when("we use named query with too few parameters") def step_use_named_query_with_too_few_parameters(context): """Send \f command with missing parameters.""" - context.cli.sendline('\\f foo_args 101') + context.cli.sendline("\\f foo_args 101") -@then('we see the named query with parameters fail with missing parameters') +@then("we see the named query with parameters fail with missing parameters") def step_see_named_query_with_parameters_fail_with_missing_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'missing substitution for $2 in query:', timeout=2) + wrappers.expect_exact(context, "missing substitution for $2 in query:", timeout=2) -@when('we use named query with too many parameters') +@when("we use named query with too many parameters") def step_use_named_query_with_too_many_parameters(context): """Send \f command with extra parameters.""" - context.cli.sendline('\\f foo_args 101 102 103 104') + context.cli.sendline("\\f foo_args 101 102 103 104") -@then('we see the named query with parameters fail with extra parameters') +@then("we see the named query with parameters fail with extra parameters") def step_see_named_query_with_parameters_fail_with_extra_parameters(context): """Wait to see select output.""" - wrappers.expect_exact( - context, 'query does not have substitution parameter $4:', timeout=2) + wrappers.expect_exact(context, "query does not have substitution parameter $4:", timeout=2) diff --git a/test/features/steps/specials.py b/test/features/steps/specials.py index e8b99e3e..04c43b13 100644 --- a/test/features/steps/specials.py +++ b/test/features/steps/specials.py @@ -1,3 +1,5 @@ +# type: ignore + """Steps for behavioral style tests are defined in this module. Each step is defined by the string decorating it. This string is used @@ -5,14 +7,14 @@ """ +from behave import then, when import wrappers -from behave import when, then -@when('we refresh completions') +@when("we refresh completions") def step_refresh_completions(context): """Send refresh command.""" - context.cli.sendline('rehash') + context.cli.sendline("rehash") @then('we see text "{text}"') @@ -20,8 +22,8 @@ def step_see_text(context, text): """Wait to see given text message.""" wrappers.expect_exact(context, text, timeout=2) -@then('we see completions refresh started') + +@then("we see completions refresh started") def step_see_refresh_started(context): """Wait to see refresh output.""" - wrappers.expect_exact( - context, 'Auto-completion refresh started in the background.', timeout=2) + wrappers.expect_exact(context, "Auto-completion refresh started in the background.", timeout=2) diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py index 1ae63d2b..7e634dde 100644 --- a/test/features/steps/utils.py +++ b/test/features/steps/utils.py @@ -1,11 +1,13 @@ +# type: ignore + import shlex def parse_cli_args_to_dict(cli_args: str): args_dict = {} for arg in shlex.split(cli_args): - if '=' in arg: - key, value = arg.split('=') + if "=" in arg: + key, value = arg.split("=") args_dict[key] = value else: args_dict[arg] = None diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index 6408f235..6c004df3 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -1,13 +1,11 @@ +# type: ignore + +from io import StringIO import re -import pexpect import sys import textwrap - -try: - from StringIO import StringIO -except ImportError: - from io import StringIO +import pexpect def expect_exact(context, expected, timeout): @@ -18,33 +16,29 @@ def expect_exact(context, expected, timeout): timedout = True if timedout: # Strip color codes out of the output. - actual = re.sub(r'\x1b\[([0-9A-Za-z;?])+[m|K]?', - '', context.cli.before) + actual = re.sub(r"\x1b\[([0-9A-Za-z;?])+[m|K]?", "", context.cli.before) raise Exception( - textwrap.dedent('''\ + textwrap.dedent( + f"""\ Expected: --- - {0!r} + {expected!r} --- Actual: --- - {1!r} + {actual!r} --- Full log: --- - {2!r} + {context.logfile.getvalue()!r} --- - ''').format( - expected, - actual, - context.logfile.getvalue() + """ ) ) def expect_pager(context, expected, timeout): - expect_exact(context, "{0}\r\n{1}{0}\r\n".format( - context.conf['pager_boundary'], expected), timeout=timeout) + expect_exact(context, f"{context.conf['pager_boundary']}\r\n{expected}{context.conf['pager_boundary']}\r\n", timeout=timeout) def run_cli(context, run_args=None, exclude_args=None): @@ -63,55 +57,47 @@ def add_arg(name, key, value): else: rendered_args.append(key) - if conf.get('host', None): - add_arg('host', '-h', conf['host']) - if conf.get('user', None): - add_arg('user', '-u', conf['user']) - if conf.get('pass', None): - add_arg('pass', '-p', conf['pass']) - if conf.get('port', None): - add_arg('port', '-P', str(conf['port'])) - if conf.get('dbname', None): - add_arg('dbname', '-D', conf['dbname']) - if conf.get('defaults-file', None): - add_arg('defaults_file', '--defaults-file', conf['defaults-file']) - if conf.get('myclirc', None): - add_arg('myclirc', '--myclirc', conf['myclirc']) - if conf.get('login_path'): - add_arg('login_path', '--login-path', conf['login_path']) + if conf.get("host", None): + add_arg("host", "-h", conf["host"]) + if conf.get("user", None): + add_arg("user", "-u", conf["user"]) + if conf.get("pass", None): + add_arg("pass", "-p", conf["pass"]) + if conf.get("port", None): + add_arg("port", "-P", str(conf["port"])) + if conf.get("dbname", None): + add_arg("dbname", "-D", conf["dbname"]) + if conf.get("defaults-file", None): + add_arg("defaults_file", "--defaults-file", conf["defaults-file"]) + if conf.get("myclirc", None): + add_arg("myclirc", "--myclirc", conf["myclirc"]) + if conf.get("login_path"): + add_arg("login_path", "--login-path", conf["login_path"]) for arg_name, arg_value in conf.items(): - if arg_name.startswith('-'): + if arg_name.startswith("-"): add_arg(arg_name, arg_name, arg_value) try: - cli_cmd = context.conf['cli_command'] + cli_cmd = context.conf["cli_command"] except KeyError: - cli_cmd = ( - '{0!s} -c "' - 'import coverage ; ' - 'coverage.process_startup(); ' - 'import mycli.main; ' - 'mycli.main.cli()' - '"' - ).format(sys.executable) + cli_cmd = f'{sys.executable} -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.click_entrypoint()"' cmd_parts = [cli_cmd] + rendered_args - cmd = ' '.join(cmd_parts) + cmd = " ".join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() context.cli.logfile = context.logfile context.exit_sent = False - context.currentdb = context.conf['dbname'] + context.currentdb = context.conf["dbname"] def wait_prompt(context, prompt=None): """Make sure prompt is displayed.""" if prompt is None: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - prompt = '{0}@{1}:{2}>'.format( - user, host, dbname), + prompt = (f"{user}@{host}:{dbname}>",) expect_exact(context, prompt, timeout=5) context.atprompt = True diff --git a/test/features/wrappager.py b/test/features/wrappager.py index 51d49095..b61a7d00 100755 --- a/test/features/wrappager.py +++ b/test/features/wrappager.py @@ -1,8 +1,9 @@ #!/usr/bin/env python + import sys -def wrappager(boundary): +def wrappager(boundary: str) -> None: print(boundary) while 1: buf = sys.stdin.read(2048) diff --git a/test/myclirc b/test/myclirc index 0c1a7ad3..680447e5 100644 --- a/test/myclirc +++ b/test/myclirc @@ -1,10 +1,29 @@ # vi: ft=dosini [main] +# Enable or disable the automatic displaying of warnings ("SHOW WARNINGS") +# after executing a SQL statement when applicable. +show_warnings = False + # Enables context sensitive auto-completion. If this is disabled the all # possible completions will be listed. smart_completion = True +# Minimum characters typed before offering completion suggestions. +# Suggestion: 3. +min_completion_trigger = 1 + +# Prefetch completion metadata for schemas in the background after launch. +# Possible values: +# always = prefetch all schemas (default) +# never = do not prefetch any schemas +# listed = prefetch only the schemas named in prefetch_schemas_list +prefetch_schemas_mode = always + +# Comma-separated list of schemas to prefetch when +# prefetch_schemas_mode = listed. Ignored in other modes. +prefetch_schemas_list = + # Multi-line mode allows breaking up the sql statements into multiple lines. If # this is set to True, then the end of the statements must have a semi-colon. # If this is set to False then sql statements can't be split into multiple @@ -16,6 +35,14 @@ multi_line = False # or "shutdown". destructive_warning = True +# Queries starting with these keywords will activate the destructive warning. +# UPDATE will not activate the warning if the statement includes a WHERE +# clause. +destructive_keywords = DROP SHUTDOWN DELETE TRUNCATE ALTER UPDATE + +# interactive query history location. +history_file = ~/.mycli-history + # log_file location. log_file = ~/.mycli.test.log @@ -30,20 +57,48 @@ log_level = DEBUG # Timing of sql statements and table rendering. timing = True +# Show the full SQL when running a favorite query. Set to False to hide. +show_favorite_query = True + # Beep after long-running queries are completed; 0 to disable. beep_after_seconds = 0 -# Table format. Possible values: ascii, double, github, -# psql, plain, simple, grid, fancy_grid, pipe, orgtbl, rst, mediawiki, html, -# latex, latex_booktabs, textile, moinmoin, jira, vertical, tsv, csv. -# Recommended: ascii +# Table format. Possible values: ascii, ascii_escaped, csv, csv-noheader, +# csv-tab, csv-tab-noheader, double, fancy_grid, github, grid, html, jira, +# jsonl, jsonl_escaped, latex, latex_booktabs, mediawiki, minimal, moinmoin, +# mysql, mysql_unicode, mysql_heavy, orgtbl, pipe, plain, psql, psql_unicode, +# rst, simple, sql-insert, sql-update, sql-update-1, sql-update-2, textile, +# tsv, tsv_noheader, vertical. +# Recommended: mysql_unicode. table_format = ascii +# Redirected otuput format +# Recommended: csv. +redirect_format = csv + +# How to display the missing value (ie NULL). Only certain table formats +# support configuring the missing value. CSV for example always uses the +# empty string, and JSON formats use native nulls. +null_string = + +# How to align numeric data in tabular output: right or left. +numeric_alignment = right + +# How to display binary values in tabular output: "hex", or "utf8". "utf8" +# means attempt to render valid UTF-8 sequences as strings, then fall back +# to hex rendering if not possible. +binary_display = hex + +# A command to run after a successful output redirect, with {} to be replaced +# with the escaped filename. Mac example: echo {} | pbcopy. Escaping is not +# reliable/safe on Windows. +post_redirect_command = "" + # Syntax coloring style. Possible values (many support the "-dark" suffix): # manni, igor, xcode, vim, autumn, vs, rrt, native, perldoc, borland, tango, emacs, # friendly, monokai, paraiso, colorful, murphy, bw, pastie, paraiso, trac, default, # fruity. -# Screenshots at http://mycli.net/syntax +# Screenshots at https://mycli.net/syntax # Can be further modified in [colors] syntax_style = default @@ -56,24 +111,78 @@ key_bindings = emacs wider_completion_menu = False # MySQL prompt -# \D - The full current date -# \d - Database name -# \h - Hostname of the server -# \m - Minutes of the current time -# \n - Newline -# \P - AM/PM -# \p - Port -# \R - The current time, in 24-hour military time (0-23) -# \r - The current time, standard 12-hour time (1-12) -# \s - Seconds of the current time -# \t - Product type (Percona, MySQL, MariaDB, TiDB) -# \A - DSN alias name (from the [alias_dsn] section) -# \u - Username -# \x1b[...m - insert ANSI escape sequence +# * \D - full current date, e.g. Sat Feb 14 15:55:48 2026 +# * \R - current hour in 24-hour time (00–23) +# * \r - current hour in 12-hour time (01–12) +# * \m - minutes of the current time +# * \s - seconds of the current time +# * \P - AM/PM +# * \d - selected database/schema +# * \h - hostname of the server +# * \H - shortened hostname of the server +# * \p - connection port +# * \j - connection socket basename +# * \J - full connection socket path +# * \k - connection socket basename OR the port +# * \K - full connection socket path OR the port +# * \T - connection SSL/TLS version +# * \t - database vendor (Percona, MySQL, MariaDB, TiDB) +# * \u - username +# * \w - number of warnings, or "(none)" (requires frequent trips to the server) +# * \W - number of warnings, or the empty string (requires frequent trips to the server) +# * \y - uptime in seconds (requires frequent trips to the server) +# * \Y - uptime in words (requires frequent trips to the server) +# * \A - DSN alias +# * \n - a newline +# * \_ - a space +# * \\ - a literal backslash +# * \x1b[...m - an ANSI escape sequence (can style with color or attributes) +# ANSI color example: prompt = '\x1b[31mroot\x1b[0m@localhost:\d> ' +# * \ - a leading sequence indicating that the rest of the prompt be styled like HTML. +# See https://python-prompt-toolkit.readthedocs.io/en/stable/pages/printing_text.html#html . +# Characters such as "&" or literal "<" and ">" must be HTML-escaped in this mode +# HTML styles cannot be combined with ANSI sequences. HTML mode takes precedence. +# HTML color example: prompt = '\root@localhost:\d> ' +# prompt = "\t \u@\h:\d> " prompt_continuation = -> -# Skip intro info on startup and outro info on exit +# Use the same prompt format strings to construct a status line in the toolbar, +# where \B in the first position refers to the default toolbar showing keystrokes +# and state. Example: +# +# toolbar = '\B\d \D' +# +# If \B is included, the additional content will begin on the next line. More +# lines can be added with \n. If \B is not included, the customized toolbar +# can be a single line. An empty value is the same as the default "\B". The +# special literal value "None" will suppress the toolbar from appearing. +toolbar = '' + +# Use the same prompt format strings to construct a terminal tab title. +# The original XTerm docs call this title the "window title", but it now +# probably refers to a terminal tab. This title is only updated as frequently +# as the database is changed. +terminal_tab_title = '' + +# Use the same prompt format strings to construct a terminal window title. +# The original XTerm docs call this title the "icon title", but it now +# probably refers to a terminal window which contains tabs. This title is +# only updated as frequently as the database is changed. +terminal_window_title = '' + +# Use the same prompt format strings to construct a window title in a terminal +# multiplexer. Currently only tmux is supported. This title is only updated +# as frequently as the database is changed. +multiplex_window_title = '' + +# Use the same prompt format strings to construct a pane title in a terminal +# multiplexer. Currently only tmux is supported. This title is only updated +# as frequently as the database is changed. +multiplex_pane_title = '' + +# Skip intro info on startup and outro info on exit, and generally reduce +# feedback. This is equivalent to giving --quiet at the command line. less_chatty = True # Use alias from --login-path instead of host name in prompt @@ -89,36 +198,162 @@ keyword_casing = auto # disabled pager on startup enable_pager = True -# Custom colors for the completion menu, toolbar, etc. +# Choose a specific pager +pager = less + +# whether to show verbose warnings about the transition away from reading my.cnf +my_cnf_transition_done = False + +# Whether to store and retrieve passwords from the system keyring. +# See the documentation for https://pypi.org/project/keyring/ for your OS. +# Note that the hostname is considered to be different if short or qualified. +# This can be overridden with --use-keyring= at the CLI. +# A password can be reset with --use-keyring=reset at the CLI. +use_keyring = False + +[search] + +# Whether to apply syntax highlighting to the preview window in fuzzy history +# search. There is a small performance penalty to enabling this. The "pygmentize" +# CLI tool must also be available. The syntax style from the "syntax_style" +# option will be respected, though additional customizations from [colors] will +# not be applied. +highlight_preview = False + +[connection] + +# character set for connections without --character-set being set +default_character_set = utf8mb4 + +# whether to enable LOAD DATA LOCAL INFILE for connections without --local-infile being set +default_local_infile = False + +# How often to send periodic background pings to the server when input is idle. Ticks are +# roughly in seconds, but may be faster. Set to zero to disable. Suggestion: 300. +default_keepalive_ticks = 0 + +# Sets the desired behavior for handling secure connections to the database server. +# Possible values: +# auto = SSL is preferred for TCP/IP connections. Will attempt to connect via SSL, but will fall +# back to cleartext as needed. Will not attempt to connect with SSL over local sockets. +# on = SSL is required. Will attempt to connect via SSL even on a local socket, and will fail if +# a secure connection is not established. +# off = do not use SSL. Will fail if the server requires a secure connection. +default_ssl_mode = auto + +# SSL CA file for connections without --ssl-ca being set +default_ssl_ca = + +# SSL CA directory for connections without --ssl-capath being set +default_ssl_capath = + +# SSL X509 cert path for connections without --ssl-cert being set +default_ssl_cert = + +# SSL X509 key for connections without --ssl-key being set +default_ssl_key = + +# SSL cipher to use for connections without --ssl-cipher being set +default_ssl_cipher = + +# whether to verify server's "Common Name" in its cert, for connections without +# --ssl-verify-server-cert being set +default_ssl_verify_server_cert = False + +[llm] + +# If set to a positive integer, truncate text/binary fields to that width +# in bytes when sending sample data, to conserve tokens. Suggestion: 1024. +prompt_field_truncate = None + +# If set to a positive integer, attempt to truncate various sections of LLM +# prompt input to that number in bytes, to conserve tokens. Suggestion: +# 1000000. +prompt_section_truncate = None + +[keys] + +# possible values: exit, none +control_d = exit + +# possible values: auto, fzf, reverse_isearch +control_r = auto + +# comma-separated list: toolkit_default, summon, advancing_summon, prefixing_summon, advance, cancel +# +# * toolkit_default - ignore other behaviors and use prompt_toolkit's default bindings +# * summon - when completions are not visible, summon them +# * advancing_summon - when completions are not visible, summon them _and_ advance in the list +# * prefixing_summon - when completions are not visible, summon them _and_ insert the common prefix +# * advance - when completions are visible, advance in the list +# * cancel - when completions are visible, toggle the list off +control_space = summon, advance + +# comma-separated list: toolkit_default, summon, advancing_summon, prefixing_summon, advance, cancel +tab = advancing_summon, advance + +# How long to wait for an Escape key sequence in vi mode. +# 0.5 seconds is the prompt_toolkit default, but vi users may find that too long. +# Shorter values mean that "Escape" alone is recognized more quickly. +vi_ttimeoutlen = 0.1 + +# How long to wait for an Escape key sequence in Emacs mode. +emacs_ttimeoutlen = 0.5 + +# Custom colors for the completion menu, toolbar, etc, with actual support +# depending on the terminal, and the property being set. +# Colors: #ffffff, bg:#ffffff, border:#ffffff. +# Attributes: (no)blink, bold, dim, hidden, inherit, italic, reverse, strike, underline. [colors] -completion-menu.completion.current = "bg:#ffffff #000000" -completion-menu.completion = "bg:#008888 #ffffff" -completion-menu.meta.completion.current = "bg:#44aaaa #000000" -completion-menu.meta.completion = "bg:#448888 #ffffff" -completion-menu.multi-column-meta = "bg:#aaffff #000000" -scrollbar.arrow = "bg:#003333" -scrollbar = "bg:#00aaaa" -selected = "#ffffff bg:#6666aa" -search = "#ffffff bg:#4444aa" -search.current = "#ffffff bg:#44aa44" -bottom-toolbar = "bg:#222222 #aaaaaa" -bottom-toolbar.off = "bg:#222222 #888888" -bottom-toolbar.on = "bg:#222222 #ffffff" -search-toolbar = noinherit bold -search-toolbar.text = nobold -system-toolbar = noinherit bold -arg-toolbar = noinherit bold -arg-toolbar.text = nobold -bottom-toolbar.transaction.valid = "bg:#222222 #00ff5f bold" -bottom-toolbar.transaction.failed = "bg:#222222 #ff005f bold" - -# style classes for colored table output +# Completion menus +completion-menu.completion.current = 'bg:#ffffff #000000' +completion-menu.completion = 'bg:#008888 #ffffff' +completion-menu.meta.completion.current = 'bg:#44aaaa #000000' +completion-menu.meta.completion = 'bg:#448888 #ffffff' +completion-menu.multi-column-meta = 'bg:#aaffff #000000' +scrollbar.arrow = 'bg:#003333' +scrollbar = 'bg:#00aaaa' + +# The prompt +prompt = '' +continuation = '' + +# Colored table output (query results) +output.table-separator = "" output.header = "#00ff5f bold" output.odd-row = "" output.even-row = "" output.null = "#808080" +output.status = "" +output.status.warning-count = "" +output.timing = "" + +# Selected text (native selection; currently unused) +selected = '#ffffff bg:#6666aa' + +# Search matches (for reverse i-search, not fuzzy search) +search = '#ffffff bg:#4444aa' +search.current = '#ffffff bg:#44aa44' -# SQL syntax highlighting overrides +# UI elements: bottom toolbar +bottom-toolbar = 'bg:#222222 #aaaaaa' +bottom-toolbar.off = 'bg:#222222 #888888' +bottom-toolbar.on = 'bg:#222222 #ffffff' +bottom-toolbar.transaction.valid = 'bg:#222222 #00ff5f bold' +bottom-toolbar.transaction.failed = 'bg:#222222 #ff005f bold' + +# UI elements: other toolbars (currently unused) +search-toolbar = 'noinherit bold' +search-toolbar.text = 'nobold' +system-toolbar = 'noinherit bold' +arg-toolbar = 'noinherit bold' +arg-toolbar.text = 'nobold' + +# SQL enhacements: matching brackets +matching-bracket.cursor = '#ff8888 bg:#880000' +matching-bracket.other = '#000000 bg:#aacccc' + +# SQL syntax highlighting overrides: normally defined by main.syntax_style # sql.comment = 'italic #408080' # sql.comment.multi-line = '' # sql.comment.single-line = '' @@ -148,10 +383,25 @@ output.null = "#808080" # sql.whitespace = '' # Favorite queries. +# You can add your favorite queries here. They will be available in the +# REPL when you type `\f` or `\f `. [favorite_queries] check = 'select "✔"' +foo_args = 'SELECT $1, "$2", "$3"' +# example = "SELECT * FROM example_table WHERE id = 1" + +# Initial commands to execute when connecting to any database. +[init-commands] +global_limit = set sql_select_limit=9999 +# read_only = "SET SESSION TRANSACTION READ ONLY" + # Use the -d option to reference a DSN. # Special characters in passwords and other strings can be escaped with URL encoding. [alias_dsn] # example_dsn = mysql://[user[:password]@][host][:port][/dbname] + +# Initial commands to execute when connecting to a DSN alias. +[alias_dsn.init-commands] +# Define one or more SQL statements per alias (semicolon-separated). +# example_dsn = "SET sql_select_limit=1000; SET time_zone='+00:00'" diff --git a/test/pytests/conftest.py b/test/pytests/conftest.py new file mode 100644 index 00000000..7cecff4d --- /dev/null +++ b/test/pytests/conftest.py @@ -0,0 +1,41 @@ +# type: ignore + +import pytest + +import mycli.sqlexecute +from test.utils import CHARACTER_SET, DATABASE, HOST, PASSWORD, PORT, SSH_HOST, SSH_PORT, SSH_USER, USER, create_db, db_connection + + +@pytest.fixture(scope="function") +def connection(): + create_db(DATABASE) + connection = db_connection(DATABASE) + yield connection + + connection.close() + + +@pytest.fixture +def cursor(connection): + with connection.cursor() as cur: + return cur + + +@pytest.fixture +def executor(connection): + return mycli.sqlexecute.SQLExecute( + database=DATABASE, + user=USER, + host=HOST, + password=PASSWORD, + port=PORT, + socket=None, + character_set=CHARACTER_SET, + local_infile=False, + ssl=None, + ssh_user=SSH_USER, + ssh_host=SSH_HOST, + ssh_port=SSH_PORT, + ssh_password=None, + ssh_key_filename=None, + ) diff --git a/test/pytests/test_app_state.py b/test/pytests/test_app_state.py new file mode 100644 index 00000000..c1f61aca --- /dev/null +++ b/test/pytests/test_app_state.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import Any + +from configobj import ConfigObj +import pytest + +from mycli.app_state import ( + AppStateMixin, + destructive_keywords_from_config, + ensure_my_cnf_sections, + llm_prompt_truncation, + normalize_ssl_mode, +) + + +class AppState(AppStateMixin): + def __init__(self, defaults_suffix: str | None = None, login_path: str | None = None) -> None: + self.defaults_suffix = defaults_suffix + self.login_path = login_path + + +@pytest.mark.parametrize('ssl_mode', ['auto', 'on', 'off']) +def test_normalize_ssl_mode_accepts_known_values(ssl_mode: str) -> None: + config = ConfigObj({'main': {'ssl_mode': ssl_mode}, 'connection': {'default_ssl_mode': 'off'}}) + + assert normalize_ssl_mode(config) == (ssl_mode, None) + + +def test_normalize_ssl_mode_falls_back_to_connection_default() -> None: + config = ConfigObj({'main': {'ssl_mode': ''}, 'connection': {'default_ssl_mode': 'on'}}) + + assert normalize_ssl_mode(config) == ('on', None) + + +def test_normalize_ssl_mode_reports_invalid_values() -> None: + config = ConfigObj({'main': {'ssl_mode': 'required'}, 'connection': {'default_ssl_mode': 'off'}}) + + ssl_mode, warning = normalize_ssl_mode(config) + + assert ssl_mode is None + assert warning == 'Invalid config option provided for ssl_mode (required); ignoring.' + + +def test_ensure_my_cnf_sections_adds_missing_sections() -> None: + config = ConfigObj({'client': {'user': 'alice'}, 'extra': {'port': '3307'}}) + + ensure_my_cnf_sections(config) + + assert config['client'] == {'user': 'alice'} + assert config['mysqld'] == {} + assert config['extra'] == {'port': '3307'} + + +def test_destructive_keywords_from_config_splits_non_empty_words() -> None: + config = ConfigObj({'main': {'destructive_keywords': 'DROP DELETE UPDATE'}}) + + assert destructive_keywords_from_config(config) == ['DROP', 'DELETE', 'UPDATE'] + + +def test_destructive_keywords_from_config_uses_default() -> None: + config = ConfigObj({'main': {}}) + + assert destructive_keywords_from_config(config) == ['DROP', 'SHUTDOWN', 'DELETE', 'TRUNCATE', 'ALTER', 'UPDATE'] + + +@pytest.mark.parametrize( + ('llm_config', 'expected'), + [ + ({'prompt_field_truncate': '12', 'prompt_section_truncate': '34'}, (12, 34)), + ({'prompt_field_truncate': 'abc', 'prompt_section_truncate': '-1'}, (0, 0)), + ({}, (0, 0)), + ], +) +def test_llm_prompt_truncation_reads_positive_integer_strings( + llm_config: dict[str, str], + expected: tuple[int, int], +) -> None: + config = ConfigObj({'main': {}, 'llm': llm_config}) + + assert llm_prompt_truncation(config) == expected + + +def test_llm_prompt_truncation_handles_missing_llm_section() -> None: + assert llm_prompt_truncation(ConfigObj({'main': {}})) == (0, 0) + + +def test_read_my_cnf_reads_allowed_sections_and_strips_quotes() -> None: + app_state = AppState() + cnf = ConfigObj({ + 'client': {'host': '"db.example.com"', 'socket': '/tmp/client.sock'}, + 'mysqld': {'socket': "'/tmp/mysql.sock'", 'port': '3307', 'user': 'mysql'}, + 'ignored': {'host': 'ignored.example.com'}, + }) + + configuration = app_state.read_my_cnf(cnf, ['host', 'socket', 'port', 'user', 'password']) + + assert configuration == { + 'host': 'db.example.com', + 'socket': '/tmp/client.sock', + 'default_socket': '/tmp/mysql.sock', + 'default_port': '3307', + 'default_user': 'mysql', + } + assert configuration['password'] is None + + +def test_read_my_cnf_includes_login_path_and_suffix_sections() -> None: + app_state = AppState(defaults_suffix='test', login_path='work') + cnf = ConfigObj({ + 'client': {'user': 'client-user'}, + 'work': {'password': 'work-pass'}, + 'clienttest': {'host': 'client-test-host'}, + 'worktest': {'database': 'work-test-db'}, + }) + + configuration = app_state.read_my_cnf(cnf, ['user', 'password', 'host', 'database']) + + assert configuration == { + 'user': 'client-user', + 'password': 'work-pass', + 'host': 'client-test-host', + 'database': 'work-test-db', + } + + +def test_merge_ssl_with_cnf_keeps_existing_ssl_and_adds_cnf_values() -> None: + app_state = AppState() + ssl: dict[str, Any] = {'ca': 'existing-ca.pem', 'cert': 'existing-cert.pem'} + cnf = { + 'ssl-ca': 'cnf-ca.pem', + 'ssl-key': 'client-key.pem', + 'ssl-verify-server-cert': 'ON', + 'ssl-empty': None, + 'host': 'db.example.com', + } + + merged = app_state.merge_ssl_with_cnf(ssl, cnf) + + assert merged == { + 'ca': 'cnf-ca.pem', + 'cert': 'existing-cert.pem', + 'key': 'client-key.pem', + 'check_hostname': True, + } + assert ssl == {'ca': 'existing-ca.pem', 'cert': 'existing-cert.pem'} diff --git a/test/pytests/test_batch_utils.py b/test/pytests/test_batch_utils.py new file mode 100644 index 00000000..603d6ce9 --- /dev/null +++ b/test/pytests/test_batch_utils.py @@ -0,0 +1,101 @@ +# type: ignore + +from io import StringIO + +import pytest + +import mycli.packages.batch_utils +from mycli.packages.batch_utils import statements_from_filehandle + + +def collect_statements(sql: str) -> list[tuple[str, int]]: + return list(statements_from_filehandle(StringIO(sql))) + + +def test_statements_from_filehandle_splits_on_statements() -> None: + statements = collect_statements('select 1;\nselect\n 2;\nselect 3; select 4;\n') + + assert statements == [ + ('select 1;', 0), + ('select\n 2;', 1), + ('select 3;', 2), + ('select 4;', 3), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_01() -> None: + statements = collect_statements('select 1;\nselect 2;') + + assert statements == [ + ('select 1;', 0), + ('select 2;', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_02() -> None: + statements = collect_statements('select 1;\nselect 2') + + assert statements == [ + ('select 1;', 0), + ('select 2', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_03() -> None: + statements = collect_statements('select 1\nwhere 1 == 1;') + + assert statements == [('select 1\nwhere 1 == 1;', 0)] + + +def test_statements_from_filehandle_rejects_overlong_statement(monkeypatch) -> None: + monkeypatch.setattr(mycli.packages.batch_utils, 'MAX_MULTILINE_BATCH_STATEMENT', 2) + + with pytest.raises(ValueError, match='Saw single input statement greater than 2 lines'): + list(statements_from_filehandle(StringIO('select 1,\n2\nwhere 1 = 1;'))) + + +def test_statements_from_filehandle_yields_incorrect_sql() -> None: + statements = collect_statements('select;\nselect 2') + + assert statements == [ + ('select;', 0), + ('select 2', 1), + ] + + +def test_statements_from_filehandle_yields_invalid_sql_01() -> None: + statements = collect_statements('sellect;\nsellect 2') + + assert statements == [ + ('sellect;', 0), + ('sellect 2', 1), + ] + + +def test_statements_from_filehandle_yields_invalid_sql_02() -> None: + statements = collect_statements('select `column;') + + assert statements == [ + ('select `column;', 0), + ] + + +def test_statements_from_filehandle_continues_when_tokenizer_returns_no_tokens(monkeypatch) -> None: + tokenize_calls: list[str] = [] + original_tokenize = mycli.packages.batch_utils.sqlglot.tokenize + + def fake_tokenize(sql: str, read: str): + tokenize_calls.append(sql) + if len(tokenize_calls) == 1: + return [] + return original_tokenize(sql, read=read) + + monkeypatch.setattr(mycli.packages.batch_utils.sqlglot, 'tokenize', fake_tokenize) + + statements = list(statements_from_filehandle(StringIO('select 1;\nselect 2;'))) + + assert tokenize_calls[0] == 'select 1;\n' + assert statements == [ + ('select 1;', 0), + ('select 2;', 1), + ] diff --git a/test/pytests/test_checkup.py b/test/pytests/test_checkup.py new file mode 100644 index 00000000..1571c139 --- /dev/null +++ b/test/pytests/test_checkup.py @@ -0,0 +1,246 @@ +import importlib.metadata +import json +from types import SimpleNamespace +import urllib.error + +from mycli.main_modes import checkup + + +class FakeUrlResponse: + def __init__(self, payload: dict) -> None: + self.payload = payload + + def __enter__(self) -> 'FakeUrlResponse': + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def read(self) -> bytes: + return json.dumps(self.payload).encode('utf8') + + +def test_pypi_api_fetch_success(monkeypatch) -> None: + def fake_urlopen(url: str, timeout: int) -> FakeUrlResponse: + assert url == 'https://pypi.org/pypi/mycli/json' + assert timeout == 5 + return FakeUrlResponse({'info': {'version': '1.2.3'}}) + + monkeypatch.setattr(checkup.urllib.request, 'urlopen', fake_urlopen) + + assert checkup.pypi_api_fetch('/mycli/json') == {'info': {'version': '1.2.3'}} + + +def test_pypi_api_fetch_url_error(monkeypatch, capsys) -> None: + def fake_urlopen(url: str, timeout: int) -> FakeUrlResponse: + raise urllib.error.URLError('offline') + + monkeypatch.setattr(checkup.urllib.request, 'urlopen', fake_urlopen) + + assert checkup.pypi_api_fetch('mycli/json') == {} + assert 'Failed to connect to PyPi on https://pypi.org/pypi/mycli/json' in capsys.readouterr().err + + +def test_dependencies_checkup(monkeypatch, capsys) -> None: + versions = { + 'cli_helpers': '1.0.0', + 'click': '2.0.0', + 'prompt_toolkit': '3.0.0', + 'pymysql': '4.0.0', + } + + def fake_version(name: str) -> str: + if name == 'tabulate': + raise importlib.metadata.PackageNotFoundError + return versions[name] + + def fake_pypi_api_fetch(fragment: str) -> dict: + dependency = fragment.strip('/').removesuffix('/json') + return {'info': {'version': f'latest-{dependency}'}} + + monkeypatch.setattr(checkup.importlib.metadata, 'version', fake_version) + monkeypatch.setattr(checkup, 'pypi_api_fetch', fake_pypi_api_fetch) + + checkup._dependencies_checkup() + output = capsys.readouterr().out + + assert '### Key Python dependencies:' in output + assert 'cli_helpers version 1.0.0 (latest latest-cli_helpers)' in output + assert 'click version 2.0.0 (latest latest-click)' in output + assert 'prompt_toolkit version 3.0.0 (latest latest-prompt_toolkit)' in output + assert 'pymysql version 4.0.0 (latest latest-pymysql)' in output + assert 'tabulate version None (latest latest-tabulate)' in output + + +def test_executables_checkup(monkeypatch, capsys) -> None: + monkeypatch.setattr( + checkup.shutil, + 'which', + lambda executable: f'/usr/bin/{executable}' if executable != 'fzf' else None, + ) + + checkup._executables_checkup() + output = capsys.readouterr().out + + assert '### External executables:' in output + assert 'The "less" executable was found' in output + assert 'The recommended "fzf" executable was not found' in output + assert 'The "pygmentize" executable was found' in output + + +def test_environment_checkup(monkeypatch, capsys) -> None: + monkeypatch.setenv('EDITOR', 'vim') + monkeypatch.delenv('VISUAL', raising=False) + + checkup._environment_checkup() + output = capsys.readouterr().out + + assert '### Environment variables:' in output + assert 'The $EDITOR environment variable was set to "vim" ' in output + assert 'The $VISUAL environment variable was not set' in output + + +def test_configuration_checkup_missing_file(capsys) -> None: + mycli = SimpleNamespace( + config={}, + config_without_package_defaults={}, + config_without_user_options={}, + ) + + checkup._configuration_checkup(mycli) + output = capsys.readouterr().out + + assert '### Missing file:' in output + assert 'The local ~/,myclirc is missing or empty.' in output + assert f'{checkup.REPO_URL}/blob/main/mycli/myclirc' in output + + +def test_configuration_checkup_reports_missing_unsupported_and_deprecated(capsys) -> None: + mycli = SimpleNamespace( + config={ + 'main': { + 'present': '', + 'missing_item': '', + }, + 'extra_section': { + 'extra_item': '', + }, + }, + config_without_package_defaults={ + 'main': { + 'present': '', + 'unsupported_item': '', + 'default_character_set': '', + }, + 'unsupported_section': { + 'anything': '', + }, + 'colors': { + 'sql.keyword': '', + }, + 'favorite_queries': { + 'demo': 'select 1', + }, + }, + config_without_user_options={ + 'main': { + 'present': '', + }, + 'colors': {}, + }, + ) + + checkup._configuration_checkup(mycli) + output = capsys.readouterr().out + + assert '### Missing in user ~/.myclirc:' in output + assert 'The entire section:\n\n [extra_section]\n' in output + assert 'The item:\n\n [main]\n missing_item =' in output + assert '### Unsupported in user ~/.myclirc:' in output + assert 'The entire section:\n\n [unsupported_section]\n' in output + assert 'The item:\n\n [main]\n unsupported_item =' in output + assert '### Deprecated in user ~/.myclirc:' in output + assert ' [main]\n default_character_set' in output + assert ' [connection]\n default_character_set' in output + assert f'{checkup.REPO_URL}/blob/main/mycli/myclirc' in output + + +def test_configuration_checkup_skips_transitioned_and_free_entry_items(capsys) -> None: + mycli = SimpleNamespace( + config={ + 'extra_section': { + 'extra_item': '', + }, + 'connection': { + 'default_character_set': '', + }, + }, + config_without_package_defaults={ + 'connection': {}, + 'unsupported_section': { + 'anything': '', + }, + 'favorite_queries': { + 'demo': 'select 1', + }, + }, + config_without_user_options={ + 'connection': {}, + 'favorite_queries': {}, + }, + ) + + checkup._configuration_checkup(mycli) + output = capsys.readouterr().out + + assert 'Missing in user ~/.myclirc:' in output + assert 'The entire section:\n\n [extra_section]\n' in output + assert 'Unsupported in user ~/.myclirc:' in output + assert 'The entire section:\n\n [unsupported_section]\n' in output + assert '[connection]\n default_character_set =' not in output + assert '[favorite_queries]' not in output + + +def test_configuration_checkup_up_to_date(capsys) -> None: + mycli = SimpleNamespace( + config={ + 'main': { + 'prompt': '', + }, + }, + config_without_package_defaults={ + 'main': { + 'prompt': '', + }, + }, + config_without_user_options={ + 'main': { + 'prompt': '', + }, + }, + ) + + checkup._configuration_checkup(mycli) + output = capsys.readouterr().out + + assert '### Configuration:' in output + assert 'User configuration all up to date!' in output + + +def test_main_checkup_calls_all_sections(monkeypatch) -> None: + calls: list[tuple[str, object]] = [] + mycli = SimpleNamespace(name='mycli') + + monkeypatch.setattr(checkup, '_dependencies_checkup', lambda: calls.append(('dependencies', None))) + monkeypatch.setattr(checkup, '_executables_checkup', lambda: calls.append(('executables', None))) + monkeypatch.setattr(checkup, '_environment_checkup', lambda: calls.append(('environment', None))) + monkeypatch.setattr(checkup, '_configuration_checkup', lambda arg: calls.append(('configuration', arg))) + + checkup.main_checkup(mycli) + + assert calls == [ + ('dependencies', None), + ('executables', None), + ('environment', None), + ('configuration', mycli), + ] diff --git a/test/pytests/test_cli_args.py b/test/pytests/test_cli_args.py new file mode 100644 index 00000000..f9171bdc --- /dev/null +++ b/test/pytests/test_cli_args.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +import builtins +from pathlib import Path +from typing import Any + +import click +import pytest + +from mycli import cli_args as cli_args_module +from mycli.cli_args import ( + EMPTY_PASSWORD_FLAG_SENTINEL, + INT_OR_STRING_CLICK_TYPE, + CliArgs, + get_password_from_file, + preprocess_cli_args, +) + + +def valid_connection_scheme(value: str) -> tuple[bool, str | None]: + scheme, _, _ = value.partition('://') + return scheme == 'mysql', scheme or None + + +def test_int_or_string_click_type_accepts_int_string_and_none() -> None: + assert INT_OR_STRING_CLICK_TYPE.convert(7, None, None) == 7 + assert INT_OR_STRING_CLICK_TYPE.convert('secret', None, None) == 'secret' + assert INT_OR_STRING_CLICK_TYPE.convert(None, None, None) is None + + +def test_int_or_string_click_type_rejects_other_values() -> None: + with pytest.raises(click.BadParameter, match='Not a valid password string'): + INT_OR_STRING_CLICK_TYPE.convert(object(), None, None) + + +def test_get_password_from_file_reads_first_line_without_trailing_newline(tmp_path: Path) -> None: + password_file = tmp_path / 'password.txt' + password_file.write_text('secret\nignored\n', encoding='utf8') + + assert get_password_from_file(str(password_file)) == 'secret' + + +def test_get_password_from_file_returns_none_for_missing_path() -> None: + assert get_password_from_file(None) is None + assert get_password_from_file('') is None + + +@pytest.mark.parametrize( + ('exception', 'expected'), + [ + (FileNotFoundError(), "Password file 'secret.txt' not found"), + (PermissionError(), "Permission denied reading password file 'secret.txt'"), + (IsADirectoryError(), "Path 'secret.txt' is a directory, not a file"), + (RuntimeError('boom'), "Error reading password file 'secret.txt': boom"), + ], +) +def test_get_password_from_file_exits_with_error_for_read_failures( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], + exception: Exception, + expected: str, +) -> None: + def raise_error(*_args: Any, **_kwargs: Any) -> None: + raise exception + + monkeypatch.setattr(builtins, 'open', raise_error) + + with pytest.raises(SystemExit) as excinfo: + get_password_from_file('secret.txt') + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_moves_dsn_from_password_to_database() -> None: + cli_args = CliArgs() + cli_args.password = 'mysql://user:pass@host/db' + + verbosity = preprocess_cli_args(cli_args, valid_connection_scheme) + + assert verbosity == 0 + assert cli_args.database == 'mysql://user:pass@host/db' + assert cli_args.password == EMPTY_PASSWORD_FLAG_SENTINEL # type: ignore[comparison-overlap] + + +def test_preprocess_cli_args_rejects_unknown_dsn_scheme(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.password = 'postgres://user:pass@host/db' + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Unknown connection scheme provided for DSN URI (postgres://)' in capsys.readouterr().err + + +def test_preprocess_cli_args_reads_password_file_when_password_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = CliArgs() + cli_args.password_file = 'secret.txt' + monkeypatch.setattr(cli_args_module, 'get_password_from_file', lambda password_file: f'from:{password_file}') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'from:secret.txt' + + +def test_preprocess_cli_args_uses_mysql_pwd_when_password_and_file_missing(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'env-secret' + + +def test_preprocess_cli_args_prefers_existing_password_over_mysql_pwd(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + cli_args.password = 'cli-secret' + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'cli-secret' + + +@pytest.mark.parametrize( + ('checkpoint', 'batch', 'expected'), + [ + (None, 'batch.sql', 'Error: --resume requires a --checkpoint file.'), + (object(), None, 'Error: --resume requires a --batch file.'), + ], +) +def test_preprocess_cli_args_validates_resume_requirements( + capsys: pytest.CaptureFixture[str], + checkpoint: object | None, + batch: str | None, + expected: str, +) -> None: + cli_args = CliArgs() + cli_args.resume = True + cli_args.checkpoint = checkpoint # type: ignore[assignment] + cli_args.batch = batch + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.verbose = 1 + cli_args.quiet = True + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Error: --verbose and --quiet are incompatible.' in capsys.readouterr().err + + +@pytest.mark.parametrize( + ('verbose', 'quiet', 'expected'), + [ + (2, False, 2), + (0, True, -1), + (0, False, 0), + ], +) +def test_preprocess_cli_args_returns_cli_verbosity(verbose: int, quiet: bool, expected: int) -> None: + cli_args = CliArgs() + cli_args.verbose = verbose + cli_args.quiet = quiet + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == expected diff --git a/test/pytests/test_cli_utils.py b/test/pytests/test_cli_utils.py new file mode 100644 index 00000000..1d01d3e6 --- /dev/null +++ b/test/pytests/test_cli_utils.py @@ -0,0 +1,39 @@ +# type: ignore + +import pytest + +from mycli.packages import cli_utils +from mycli.packages.cli_utils import ( + filtered_sys_argv, + is_valid_connection_scheme, +) + + +@pytest.mark.parametrize( + ('argv', 'expected'), + [ + (['mycli', '-h'], ['--help']), + (['mycli', '-h', 'example.com'], ['-h', 'example.com']), + ], +) +def test_filtered_sys_argv(monkeypatch, argv, expected): + monkeypatch.setattr(cli_utils.sys, 'argv', argv) + + assert filtered_sys_argv() == expected + + +@pytest.mark.parametrize( + ('text', 'is_valid', 'invalid_scheme'), + [ + ('localhost', False, None), + ('mysql://user@localhost/db', True, None), + ('mysqlx://user@localhost/db', True, None), + ('tcp://localhost:3306', True, None), + ('socket:///tmp/mysql.sock', True, None), + ('ssh://user@example.com', True, None), + ('postgres://user@localhost/db', False, 'postgres'), + ('http://example.com', False, 'http'), + ], +) +def test_is_valid_connection_scheme(text, is_valid, invalid_scheme): + assert is_valid_connection_scheme(text) == (is_valid, invalid_scheme) diff --git a/test/pytests/test_clibuffer.py b/test/pytests/test_clibuffer.py new file mode 100644 index 00000000..d502e009 --- /dev/null +++ b/test/pytests/test_clibuffer.py @@ -0,0 +1,115 @@ +from dataclasses import dataclass +from types import SimpleNamespace + +import pytest + +from mycli import clibuffer + + +@dataclass +class DummyDocument: + text: str + + +@dataclass +class DummyBuffer: + document: DummyDocument + + +@dataclass +class DummyLayout: + buffer: DummyBuffer + requested_names: list[str] + + def get_buffer_by_name(self, name: str) -> DummyBuffer: + self.requested_names.append(name) + return self.buffer + + +def make_app_for_text(text: str) -> tuple[SimpleNamespace, DummyLayout]: + layout = DummyLayout( + buffer=DummyBuffer(document=DummyDocument(text=text)), + requested_names=[], + ) + return SimpleNamespace(layout=layout), layout + + +def test_multiline_exception_handles_favorite_queries_only_after_blank_line() -> None: + assert clibuffer._multiline_exception(r'\fs demo select 1; select 2') is False + assert clibuffer._multiline_exception('\\fs demo select 1; select 2\n') is True + + +@pytest.mark.parametrize( + ('text', 'expected'), + ( + (r'\dt', True), + ('select 1 //', True), + ('select 1 \\g', True), + ('select 1 \\G', True), + ('select 1 \\e', True), + ('select 1 \\edit', True), + ('select 1 \\clip', True), + ('help topic', True), + ('HELP topic', True), + (' ', True), + ('select 1', False), + ), +) +def test_multiline_exception_detects_commands_terminators_and_plain_sql( + monkeypatch, + text: str, + expected: bool, +) -> None: + monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: '//') + monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()}) + + assert clibuffer._multiline_exception(text) is expected + + +def test_cli_is_multiline_returns_false_when_multiline_mode_is_disabled(monkeypatch) -> None: + mycli = SimpleNamespace(multi_line=False) + + def fail_get_app() -> None: + raise AssertionError('get_app() should not be called when multiline mode is disabled') + + monkeypatch.setattr(clibuffer, 'get_app', fail_get_app) + + multiline_filter = clibuffer.cli_is_multiline(mycli) + + assert multiline_filter() is False + + +@pytest.mark.parametrize('text', ('help\tselect', 'HELP\nselect')) +def test_multiline_exception_recognizes_non_backslashed_special_commands_with_general_whitespace( + monkeypatch, + text: str, +) -> None: + monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';') + monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object(), 'exit': object()}) + + assert clibuffer._multiline_exception(text) is True + + +@pytest.mark.parametrize( + ('text', 'expected'), + ( + ('select 1', True), + ('help select', False), + ), +) +def test_cli_is_multiline_uses_buffer_text_when_multiline_mode_is_enabled( + monkeypatch, + text: str, + expected: bool, +) -> None: + app, layout = make_app_for_text(text) + mycli = SimpleNamespace(multi_line=True) + + monkeypatch.setattr(clibuffer, 'get_app', lambda: app) + monkeypatch.setattr(clibuffer.iocommands, 'get_current_delimiter', lambda: ';') + monkeypatch.setattr(clibuffer, 'SPECIAL_COMMANDS', {'help': object()}) + + multiline_filter = clibuffer.cli_is_multiline(mycli) + + assert multiline_filter() is expected + assert layout.requested_names == [clibuffer.DEFAULT_BUFFER] diff --git a/test/pytests/test_clistyle.py b/test/pytests/test_clistyle.py new file mode 100644 index 00000000..3e152c9f --- /dev/null +++ b/test/pytests/test_clistyle.py @@ -0,0 +1,191 @@ +# type: ignore + +"""Tests for the mycli.clistyle module.""" + +from types import SimpleNamespace + +from prompt_toolkit.styles import Style as PromptStyle +from pygments.style import Style as PygmentsStyle +from pygments.token import Token +from pygments.util import ClassNotFound + +from mycli import clistyle + + +def test_parse_pygments_style_handles_style_classes_instances_and_dict_values() -> None: + class DemoStyle(PygmentsStyle): + default_style = '' + styles = { + Token.Name: 'bold', + Token.String: 'ansired', + } + + token_type, style_value = clistyle.parse_pygments_style( + 'Token.String', + DemoStyle, + {'Token.String': 'Token.Name'}, + ) + assert token_type == Token.String + assert style_value == 'bold' + + token_type, style_value = clistyle.parse_pygments_style( + 'Token.String', + DemoStyle(), + {'Token.String': 'Token.Name'}, + ) + assert token_type == Token.String + assert style_value == 'bold' + + token_type, style_value = clistyle.parse_pygments_style( + 'Token.String', + 'unused', + {'Token.String': 'ansiblue'}, + ) + assert token_type == Token.String + assert style_value == 'ansiblue' + + +def test_is_valid_pygments_returns_true_and_false(monkeypatch) -> None: + assert clistyle.is_valid_pygments('ansired') is True + + class FailingPygmentsStyle: + def __init_subclass__(cls, **kwargs) -> None: + raise AssertionError('bad style') + + monkeypatch.setattr(clistyle, 'PygmentsStyle', FailingPygmentsStyle) + + assert clistyle.is_valid_pygments('invalid') is False + + +def test_is_valid_ptoolkit_returns_true_and_false(monkeypatch) -> None: + assert clistyle.is_valid_ptoolkit('bold') is True + + class FailingPromptStyle: + def __init__(self, _rules) -> None: + raise ValueError('bad style') + + monkeypatch.setattr(clistyle, 'Style', FailingPromptStyle) + + assert clistyle.is_valid_ptoolkit('invalid') is False + + +def test_style_factory_ptoolkit_builds_styles_and_falls_back(monkeypatch, caplog) -> None: + calls: list[str] = [] + native_style = object() + + def fake_get_style_by_name(name: str): + calls.append(name) + if name == 'missing': + raise ClassNotFound('missing') + if name == 'native': + return native_style + raise AssertionError(f'unexpected style {name}') + + class FakeStyle: + def __init__(self, rules) -> None: + self.rules = list(rules) + + monkeypatch.setattr(clistyle.pygments.styles, 'get_style_by_name', fake_get_style_by_name) + monkeypatch.setattr( + clistyle, + 'parse_pygments_style', + lambda token, style, cli_style: { + 'Token.Prompt': (Token.Prompt, 'token-valid'), + 'Token.Toolbar': (Token.Toolbar, 'token-invalid'), + 'Token.Name': (Token.Name, 'token-invalid'), + }[token], + ) + monkeypatch.setattr(clistyle, 'is_valid_ptoolkit', lambda value: value in {'token-valid', 'prompt-valid'}) + monkeypatch.setattr(clistyle, 'Style', FakeStyle) + monkeypatch.setattr(clistyle, 'style_from_pygments_cls', lambda style: ('pygments-style', style)) + monkeypatch.setattr(clistyle, 'merge_styles', lambda styles: styles) + + cli_style = { + 'Token.Prompt': 'Token.Name', + 'Token.Toolbar': 'Token.Name', + 'Token.Name': 'ignored', + 'prompt': 'prompt-valid', + 'search': 'prompt-invalid', + } + + with caplog.at_level('ERROR', logger='mycli.clistyle'): + styles = clistyle.style_factory_ptoolkit('missing', cli_style) + + assert calls == ['missing', 'native'] + assert styles[0] == ('pygments-style', native_style) + assert styles[1].rules == [('bottom-toolbar', 'noreverse')] + assert styles[2].rules == [ + ('prompt', 'token-valid'), + ('prompt', 'prompt-valid'), + ] + assert ('bottom-toolbar', 'token-invalid') not in styles[2].rules + assert ('search', 'prompt-invalid') not in styles[2].rules + assert 'Unhandled style / class name: Token.Name' in caplog.text + + +def test_style_factory_helpers_updates_known_tokens(monkeypatch, caplog) -> None: + base_styles = {Token.Output.Header: 'ansiyellow'} + style_class = SimpleNamespace(styles=base_styles) + + monkeypatch.setattr(clistyle.pygments.styles, 'get_style_by_name', lambda name: style_class) + monkeypatch.setattr( + clistyle, + 'parse_pygments_style', + lambda token, style, cli_style: { + 'Token.Prompt': (Token.Prompt, 'ansiblue'), + 'Token.Toolbar': (Token.Toolbar, 'skip-me'), + }[token], + ) + monkeypatch.setattr(clistyle, 'is_valid_pygments', lambda value: value != 'skip-me') + + cli_style = { + 'Token.Prompt': 'Token.Name', + 'Token.Toolbar': 'Token.Name', + 'search': 'ansigreen', + 'search.current': 'skip-me', + 'sql.keyword': 'ansired', + 'sql.string': 'skip-me', + 'unknown': 'skip-me', + } + + with caplog.at_level('ERROR', logger='mycli.clistyle'): + output_style = clistyle.style_factory_helpers('native', cli_style) + + assert output_style.styles[Token.Prompt] == 'ansiblue' + assert output_style.styles[Token.SearchMatch] == 'ansigreen' + assert Token.SearchMatch.Current not in output_style.styles + assert output_style.styles[Token.Keyword] == 'ansired' + assert output_style.styles[Token.Output.Header] == 'ansiyellow' + assert Token.Toolbar not in output_style.styles + assert output_style.styles[Token.String] != 'skip-me' + assert 'Unhandled style / class name: unknown' in caplog.text + + +def test_style_factory_helpers_falls_back_and_copies_warning_styles(monkeypatch) -> None: + native_styles = { + Token.Text: 'ansiblack', + Token.Warnings.Header: 'ansimagenta', + Token.Warnings.Status: 'ansicyan', + } + + def fake_get_style_by_name(name: str): + if name == 'missing': + raise ClassNotFound('missing') + if name == 'native': + return SimpleNamespace(styles=native_styles.copy()) + raise AssertionError(f'unexpected style {name}') + + monkeypatch.setattr(clistyle.pygments.styles, 'get_style_by_name', fake_get_style_by_name) + + output_style = clistyle.style_factory_helpers('missing', {}, warnings=True) + + assert output_style.styles[Token.Warnings.Header] == 'ansimagenta' + assert output_style.styles[Token.Warnings.Status] == 'ansicyan' + assert output_style.styles[Token.Output.Header] == 'ansimagenta' + assert output_style.styles[Token.Output.Status] == 'ansicyan' + + +def test_style_factory_ptoolkit_returns_merged_style_object() -> None: + style = clistyle.style_factory_ptoolkit('native', {'prompt': 'bold'}) + + assert style.get_attrs_for_style_str('class:prompt') == PromptStyle([('prompt', 'bold')]).get_attrs_for_style_str('class:prompt') diff --git a/test/pytests/test_clitoolbar.py b/test/pytests/test_clitoolbar.py new file mode 100644 index 00000000..d0ffc104 --- /dev/null +++ b/test/pytests/test_clitoolbar.py @@ -0,0 +1,143 @@ +# type: ignore + +from types import SimpleNamespace +from unittest.mock import MagicMock + +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.key_binding.vi_state import InputMode +import pytest + +from mycli import clitoolbar + + +def make_mycli( + *, + smart_completion: bool = True, + multi_line: bool = False, + editing_mode: EditingMode = EditingMode.EMACS, + toolbar_error_message: str | None = None, + refreshing: bool = False, + prefetching: bool = False, +): + return SimpleNamespace( + completer=SimpleNamespace(smart_completion=smart_completion), + multi_line=multi_line, + prompt_session=SimpleNamespace(editing_mode=editing_mode), + toolbar_error_message=toolbar_error_message, + completion_refresher=SimpleNamespace(is_refreshing=MagicMock(return_value=refreshing)), + schema_prefetcher=SimpleNamespace(is_prefetching=MagicMock(return_value=prefetching)), + get_custom_toolbar=MagicMock(return_value="custom toolbar"), + ) + + +def test_create_toolbar_tokens_func_shows_initial_help() -> None: + mycli = make_mycli() + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: True, None, mycli.get_custom_toolbar) + result = toolbar() + + assert ("class:bottom-toolbar", "right-arrow accepts full-line suggestion") in result + assert ("class:bottom-toolbar", "[F2] Smart-complete:") in result + assert ("class:bottom-toolbar.on", "ON ") in result + assert ("class:bottom-toolbar", "[F3] Multiline:") in result + assert ("class:bottom-toolbar.off", "OFF") in result + + +def test_create_toolbar_tokens_func_clears_toolbar_error_message() -> None: + mycli = make_mycli(toolbar_error_message="boom") + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: False, None, mycli.get_custom_toolbar) + first = toolbar() + second = toolbar() + + assert ("class:bottom-toolbar.transaction.failed", "boom") in first + assert ("class:bottom-toolbar.transaction.failed", "boom") not in second + assert mycli.toolbar_error_message is None + assert ("class:bottom-toolbar", "right-arrow accepts full-line suggestion") not in first + + +def test_create_toolbar_tokens_func_shows_prefetching() -> None: + mycli = make_mycli(prefetching=True) + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: False, None, mycli.get_custom_toolbar) + result = toolbar() + + assert ("class:bottom-toolbar", "Prefetching schemas…") in result + + +def test_create_toolbar_tokens_func_shows_multiline_vi_and_refreshing(monkeypatch) -> None: + mycli = make_mycli( + smart_completion=False, + multi_line=True, + editing_mode=EditingMode.VI, + refreshing=True, + ) + monkeypatch.setattr(clitoolbar.special, 'get_current_delimiter', lambda: '$$') + monkeypatch.setattr(clitoolbar, '_get_vi_mode', lambda: 'N') + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: False, None, mycli.get_custom_toolbar) + result = toolbar() + + assert ("class:bottom-toolbar.off", "OFF") in result + assert ("class:bottom-toolbar", "[F3] Multiline:") in result + assert ("class:bottom-toolbar.on", "ON ") in result + assert ("class:bottom-toolbar", "Vi:") in result + assert ("class:bottom-toolbar.on", "N") in result + assert ('class:bottom-toolbar.on', '$$') in result + assert ("class:bottom-toolbar", "Refreshing completions…") in result + + +def test_create_toolbar_tokens_func_applies_custom_format(monkeypatch) -> None: + mycli = make_mycli(multi_line=True, refreshing=True) + monkeypatch.setattr(clitoolbar.special, 'get_current_delimiter', lambda: '$$') + + formatted = [("class:bottom-toolbar", "CUSTOM")] + to_formatted_text = MagicMock(return_value=formatted) + monkeypatch.setattr(clitoolbar, 'to_formatted_text', to_formatted_text) + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: True, r'\Bfmt', mycli.get_custom_toolbar) + result = toolbar() + + mycli.get_custom_toolbar.assert_called_once_with('fmt') + to_formatted_text.assert_called_once_with("custom toolbar", style='class:bottom-toolbar') + assert ('class:bottom-toolbar', '\n') in result + assert ("class:bottom-toolbar", "CUSTOM") in result + assert ("class:bottom-toolbar", "right-arrow accepts full-line suggestion") in result + assert ("class:bottom-toolbar", "Refreshing completions…") in result + + +def test_create_toolbar_tokens_func_replaces_default_toolbar_for_plain_custom_format(monkeypatch) -> None: + mycli = make_mycli(multi_line=True, toolbar_error_message='boom', refreshing=True) + monkeypatch.setattr(clitoolbar.special, 'get_current_delimiter', lambda: '$$') + + formatted = [('class:bottom-toolbar', 'PLAIN CUSTOM')] + to_formatted_text = MagicMock(return_value=formatted) + monkeypatch.setattr(clitoolbar, 'to_formatted_text', to_formatted_text) + + toolbar = clitoolbar.create_toolbar_tokens_func(mycli, lambda: True, 'fmt', mycli.get_custom_toolbar) + result = toolbar() + + mycli.get_custom_toolbar.assert_called_once_with('fmt') + to_formatted_text.assert_called_once_with('custom toolbar', style='class:bottom-toolbar') + assert ('class:bottom-toolbar', 'PLAIN CUSTOM') in result + assert ('class:bottom-toolbar', '[Tab] Complete') not in result + assert ('class:bottom-toolbar', '[F1] Help') not in result + assert ('class:bottom-toolbar', 'right-arrow accepts full-line suggestion') in result + assert ('class:bottom-toolbar.transaction.failed', 'boom') in result + + +@pytest.mark.parametrize( + ('input_mode', 'expected'), + [ + (InputMode.INSERT, 'I'), + (InputMode.NAVIGATION, 'N'), + (InputMode.REPLACE, 'R'), + (InputMode.REPLACE_SINGLE, 'R'), + (InputMode.INSERT_MULTIPLE, 'M'), + ], +) +def test_get_vi_mode(monkeypatch, input_mode: InputMode, expected: str) -> None: + app = SimpleNamespace(vi_state=SimpleNamespace(input_mode=input_mode)) + monkeypatch.setattr(clitoolbar, 'get_app', lambda: app) + + assert clitoolbar._get_vi_mode() == expected diff --git a/test/pytests/test_completion_engine.py b/test/pytests/test_completion_engine.py new file mode 100644 index 00000000..e6b4bc89 --- /dev/null +++ b/test/pytests/test_completion_engine.py @@ -0,0 +1,1783 @@ +# type: ignore + +from types import SimpleNamespace + +import pytest +import sqlparse + +from mycli.packages import completion_engine, special +from mycli.packages.completion_engine import ( + _aliases, + _build_suggest_context, + _charset_suggestion, + _emit_binary_or_comma, + _emit_blank_token, + _emit_character_set, + _emit_collation, + _emit_column_for_tables, + _emit_database, + _emit_lparen, + _emit_none_token, + _emit_nothing, + _emit_on, + _emit_procedure, + _emit_relation_like, + _emit_relation_name, + _emit_select_like, + _emit_show, + _emit_star, + _emit_to, + _emit_user, + _emit_where_token, + _enum_value_suggestion, + _find_doubled_backticks, + _is_single_or_double_quoted, + _is_where_or_having, + _keyword_and_special_suggestions, + _keyword_suggestions, + _normalize_token_value, + _parent_name, + _parse_suggestion_statement, + _tables, + _token_is_binary_or_comma, + _token_is_blank, + _token_is_lparen, + _token_is_none, + _token_is_relation_keyword, + _token_value_is, + _tokens_wo_space, + _word_starts_with_digit_or_dot, + _word_starts_with_quote, + identifies, + is_inside_quotes, + suggest_based_on_last_token, + suggest_special, + suggest_type, +) + + +def sorted_dicts(dicts): + """input is a list of dicts.""" + return sorted(tuple(x.items()) for x in dicts) + + +def flattened_tokens(text): + return list(sqlparse.parse(text)[0].flatten()) + + +def value_tokens(*values): + return [SimpleNamespace(value=value) for value in values] + + +def empty_identifier(): + return SimpleNamespace(get_parent_name=lambda: None) + + +def last_non_whitespace_token(text): + parsed = sqlparse.parse(text)[0] + return parsed.token_prev(len(parsed.tokens) - 1)[1] + + +def test_select_suggests_cols_with_visible_table_scope(): + suggestions = suggest_type("SELECT FROM tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_select_suggests_cols_with_qualified_table_scope(): + suggestions = suggest_type("SELECT FROM sch.tabl", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [("sch", "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE ", + "SELECT * FROM tabl WHERE (", + "SELECT * FROM tabl WHERE bar OR ", + "SELECT * FROM tabl WHERE foo = 1 AND ", + "SELECT * FROM tabl WHERE (bar > 10 AND ", + "SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (", + "SELECT * FROM tabl WHERE 10 < ", + "SELECT * FROM tabl WHERE foo BETWEEN ", + "SELECT * FROM tabl WHERE foo BETWEEN foo AND ", + ], +) +def test_where_suggests_columns_functions(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_where_equals_suggests_enum_values_first(): + expression = "SELECT * FROM tabl WHERE foo = " + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "enum_value", "tables": [(None, "tabl", None)], "column": "foo", "parent": None}, + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_enum_value_suggestion_returns_none_without_equals_context(): + expression = 'SELECT * FROM tabl WHERE foo' + suggestion = _enum_value_suggestion(expression, expression) + assert suggestion is None + + +def test_enum_value_suggestion_returns_column_and_tables(): + expression = 'SELECT * FROM tabl WHERE foo = ' + suggestion = _enum_value_suggestion(expression, expression) + assert suggestion == { + 'type': 'enum_value', + 'tables': [(None, 'tabl', None)], + 'column': 'foo', + 'parent': None, + } + + +def test_enum_value_suggestion_handles_qualified_backticked_identifier(): + expression = 'SELECT * FROM sch.tabl WHERE `tabl`.`foo` = ' + suggestion = _enum_value_suggestion(expression, expression) + assert suggestion == { + 'type': 'enum_value', + 'tables': [('sch', 'tabl', None)], + 'column': '`foo`', + 'parent': '`tabl`', + } + + +def test_enum_value_suggestion_returns_none_inside_quotes(): + full_text = 'SELECT * FROM tabl WHERE "foo = ' + text_before_cursor = 'SELECT * FROM tabl WHERE "foo = ' + suggestion = _enum_value_suggestion(text_before_cursor, full_text) + assert suggestion is None + + +@pytest.mark.parametrize( + ('tokens', 'expected'), + [ + (value_tokens('character', 'set'), [{'type': 'character_set'}]), + (value_tokens('x', 'character', 'set', ' '), [{'type': 'character_set'}]), + (value_tokens('collate'), [{'type': 'collation'}]), + (value_tokens('select', 'foo'), None), + ], +) +def test_charset_suggestion(tokens, expected): + assert _charset_suggestion(tokens) == expected + + +def test_keyword_suggestions(): + assert _keyword_suggestions() == [{'type': 'keyword'}] + + +def test_keyword_and_special_suggestions(): + assert _keyword_and_special_suggestions() == [{'type': 'keyword'}, {'type': 'special'}] + + +def test_parse_suggestion_statement_returns_statement_and_nonspace_tokens(): + tokens_wo_space = _tokens_wo_space('select 1') + assert [token.value for token in tokens_wo_space] == ['select', '1'] + + +def test_parse_suggestion_statement_raises_type_error_for_invalid_input_type(): + with pytest.raises(TypeError): + _parse_suggestion_statement(None) # type: ignore[arg-type] + + +def test_normalize_token_value_handles_string(): + assert _normalize_token_value('SELECT') == 'select' + + +def test_normalize_token_value_handles_none(): + assert _normalize_token_value(None) is None + + +def test_normalize_token_value_handles_plain_token(): + token = SimpleNamespace(value='SHOW') + assert _normalize_token_value(token) == 'show' + + +def test_normalize_token_value_handles_comparison_token(): + comparison = sqlparse.parse('a.id = d.')[0].tokens[0] + assert _normalize_token_value(comparison) == 'd.' + + +def test_build_suggest_context_populates_fields(): + identifier = empty_identifier() + context = _build_suggest_context( + 'SHOW', + 'show ', + None, + 'show ', + identifier, + ) + + assert context.token == 'SHOW' + assert context.token_value == 'show' + assert context.text_before_cursor == 'show ' + assert context.word_before_cursor is None + assert context.full_text == 'show ' + assert context.identifier is identifier + assert str(context.parsed_cb()) == 'show ' + assert [token.value for token in context.tokens_wo_space_cb()] == ['show'] + + +def test_build_suggest_context_handles_none_token(): + context = _build_suggest_context( + None, + '', + None, + '', + empty_identifier(), + ) + + assert context.token is None + assert context.token_value is None + assert str(context.parsed_cb()) == '' + assert context.tokens_wo_space_cb() == [] + + +@pytest.mark.parametrize( + ('text_before_cursor', 'expected'), + [ + ("select 'foo", True), + ('select "foo', True), + ('select `foo', False), + ('select foo', False), + ], +) +def test_is_single_or_double_quoted(text_before_cursor, expected): + context = _build_suggest_context( + None, + text_before_cursor, + None, + text_before_cursor, + empty_identifier(), + ) + assert _is_single_or_double_quoted(context) is expected + + +def test_parent_name_returns_identifier_parent(): + identifier = SimpleNamespace(get_parent_name=lambda: 'sch') + context = _build_suggest_context(None, '', None, '', identifier) + assert _parent_name(context) == 'sch' + + +def test_parent_name_returns_empty_list_without_parent(): + context = _build_suggest_context(None, '', None, '', empty_identifier()) + assert _parent_name(context) == [] + + +def test_tables_returns_extracted_tables_from_full_text(): + full_text = 'SELECT * FROM abc a, sch.def d' + context = _build_suggest_context(None, '', None, full_text, empty_identifier()) + assert _tables(context) == [ + (None, 'abc', 'a'), + ('sch', 'def', 'd'), + ] + + +def test_aliases_prefers_alias_and_falls_back_to_table_name(): + tables = [ + (None, 'abc', 'a'), + ('sch', 'def', ''), + ] + assert _aliases(tables) == ['a', 'def'] + + +@pytest.mark.parametrize( + ('word_before_cursor', 'expected'), + [ + ('9foo', True), + ('.foo', True), + ('foo', False), + (None, False), + ], +) +def test_word_starts_with_digit_or_dot(word_before_cursor, expected): + context = _build_suggest_context( + None, + '', + word_before_cursor, + '', + empty_identifier(), + ) + assert _word_starts_with_digit_or_dot(context) is expected + + +@pytest.mark.parametrize( + ('word_before_cursor', 'expected'), + [ + ("'foo", True), + ('"foo', True), + ('foo', False), + (None, False), + ], +) +def test_word_starts_with_quote(word_before_cursor, expected): + context = _build_suggest_context( + None, + '', + word_before_cursor, + '', + empty_identifier(), + ) + assert _word_starts_with_quote(context) is expected + + +def test_token_is_none_true_for_none_token(): + context = _build_suggest_context(None, '', None, '', empty_identifier()) + assert _token_is_none(context) is True + + +def test_token_is_none_false_for_non_none_token(): + context = _build_suggest_context('select', '', None, '', empty_identifier()) + assert _token_is_none(context) is False + + +@pytest.mark.parametrize( + ('token', 'expected'), + [ + ('', True), + ('select', False), + (None, True), + ], +) +def test_token_is_blank(token, expected): + context = _build_suggest_context(token, '', None, '', empty_identifier()) + assert _token_is_blank(context) is expected + + +@pytest.mark.parametrize( + ('token', 'values', 'expected'), + [ + ('select', ('select', 'where'), True), + ('show', ('select', 'where'), False), + (None, ('select',), False), + ], +) +def test_token_value_is(token, values, expected): + context = _build_suggest_context(token, '', None, '', empty_identifier()) + assert _token_value_is(context, *values) is expected + + +@pytest.mark.parametrize( + ('token', 'expected'), + [ + ('(', True), + ('any(', True), + ('select', False), + (None, False), + ], +) +def test_token_is_lparen(token, expected): + context = _build_suggest_context(token, '', None, '', empty_identifier()) + assert _token_is_lparen(context) is expected + + +@pytest.mark.parametrize( + ('token', 'text_before_cursor', 'full_text', 'expected'), + [ + (last_non_whitespace_token('SELECT * FROM foo JOIN '), 'SELECT * FROM foo JOIN ', 'SELECT * FROM foo JOIN ', True), + ('from', 'from ', 'from ', True), + ('truncate', 'truncate ', 'truncate ', True), + ('like', 'like ', 'create table new like ', True), + ('like', 'like ', 'select * from foo like ', False), + ('select', 'select ', 'select ', False), + ], +) +def test_token_is_relation_keyword(token, text_before_cursor, full_text, expected): + context = _build_suggest_context(token, text_before_cursor, None, full_text, empty_identifier()) + assert _token_is_relation_keyword(context) is expected + + +@pytest.mark.parametrize( + ('token', 'expected'), + [ + (',', True), + ('=', True), + ('and', True), + ('select', False), + (None, False), + ], +) +def test_token_is_binary_or_comma(token, expected): + context = _build_suggest_context(token, '', None, '', empty_identifier()) + assert _token_is_binary_or_comma(context) is expected + + +def test_emit_none_token(): + context = _build_suggest_context(None, '', None, '', empty_identifier()) + assert _emit_none_token(context) == [{'type': 'keyword'}] + + +def test_emit_blank_token(): + context = _build_suggest_context('', '', None, '', empty_identifier()) + assert _emit_blank_token(context) == [{'type': 'keyword'}, {'type': 'special'}] + + +def test_emit_star(): + context = _build_suggest_context('*', '', None, '', empty_identifier()) + assert _emit_star(context) == [{'type': 'keyword'}] + + +def test_emit_lparen_exists_where(): + text = 'SELECT * FROM foo WHERE EXISTS (' + context = _build_suggest_context('(', text, None, text, empty_identifier()) + assert _emit_lparen(context) == [{'type': 'keyword'}] + + +def test_emit_lparen_join_using(): + text = 'select * from abc inner join def using (' + context = _build_suggest_context('(', text, None, text, empty_identifier()) + assert _emit_lparen(context) == [{'type': 'column', 'tables': [(None, 'abc', None), (None, 'def', None)], 'drop_unique': True}] + + +def test_emit_lparen_show(): + text = 'SHOW (' + context = _build_suggest_context('(', text, None, text, empty_identifier()) + assert _emit_lparen(context) == [{'type': 'show'}] + + +def test_emit_lparen_function_argument_list(): + text = 'SELECT MAX(' + full_text = 'SELECT MAX( FROM tbl' + context = _build_suggest_context('(', text, None, full_text, empty_identifier()) + assert _emit_lparen(context) == [{'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +def test_emit_procedure(): + context = _build_suggest_context('call', '', None, '', empty_identifier()) + assert _emit_procedure(context) == [{'type': 'procedure', 'schema': []}] + + +def test_emit_character_set(): + context = _build_suggest_context('set', '', None, '', empty_identifier()) + assert _emit_character_set(context) == [{'type': 'character_set'}] + + +def test_emit_column_for_tables(): + full_text = 'SELECT * FROM abc a, sch.def d' + context = _build_suggest_context('select', '', None, full_text, empty_identifier()) + assert _emit_column_for_tables(context) == [ + { + 'type': 'column', + 'tables': [ + (None, 'abc', 'a'), + ('sch', 'def', 'd'), + ], + } + ] + + +def test_emit_nothing(): + context = _build_suggest_context('as', '', None, '', empty_identifier()) + assert _emit_nothing(context) == [] + + +def test_emit_show(): + context = _build_suggest_context('show', '', None, '', empty_identifier()) + assert _emit_show(context) == [{'type': 'show'}] + + +def test_emit_to_for_change_statement(): + text = 'change master to ' + context = _build_suggest_context('to', text, None, text, empty_identifier()) + assert _emit_to(context) == [{'type': 'change'}] + + +def test_emit_to_for_non_change_statement(): + text = 'grant all on db.* to ' + context = _build_suggest_context('to', text, None, text, empty_identifier()) + assert _emit_to(context) == [{'type': 'user'}] + + +def test_emit_user(): + context = _build_suggest_context('user', '', None, '', empty_identifier()) + assert _emit_user(context) == [{'type': 'user'}] + + +def test_emit_collation(): + context = _build_suggest_context('collate', '', None, '', empty_identifier()) + assert _emit_collation(context) == [{'type': 'collation'}] + + +@pytest.mark.xfail +def test_emit_select_like_with_parent_filters_tables(): + identifier = SimpleNamespace(get_parent_name=lambda: 't1') + text = 'SELECT t1.' + full_text = 'SELECT t1. FROM tabl1 t1, tabl2 t2' + context = _build_suggest_context('select', text, None, full_text, identifier) + assert sorted_dicts(_emit_select_like(context)) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, + # xfail because these are also currently returned + # {'type': 'table', 'schema': 't1'}, + # {'type': 'view', 'schema': 't1'}, + # {'type': 'function', 'schema': 't1'}, + ]) + + +def test_emit_select_like_inside_backticks_adds_keyword(): + text = 'SELECT `a' + full_text = 'SELECT `a FROM tabl' + context = _build_suggest_context('select', text, None, full_text, empty_identifier()) + assert sorted_dicts(_emit_select_like(context)) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'keyword'}, + ]) + + +def test_emit_select_like_default(): + text = 'SELECT ' + full_text = 'SELECT FROM tabl' + context = _build_suggest_context('select', text, None, full_text, empty_identifier()) + assert sorted_dicts(_emit_select_like(context)) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'introducer'}, + {'type': 'alias', 'aliases': ['tabl']}, + ]) + + +def test_emit_relation_like_with_schema_parent(): + identifier = SimpleNamespace(get_parent_name=lambda: 'sch') + text = 'INSERT INTO sch.' + context = _build_suggest_context('into', text, None, text, identifier) + assert sorted_dicts(_emit_relation_like(context)) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}, + {'type': 'view', 'schema': 'sch'}, + ]) + + +def test_emit_relation_like_join_adds_database_and_join_flag(): + text = 'SELECT * FROM foo JOIN ' + token = last_non_whitespace_token(text) + context = _build_suggest_context(token, text, None, text, empty_identifier()) + assert sorted_dicts(_emit_relation_like(context)) == sorted_dicts([ + {'type': 'database'}, + {'type': 'table', 'schema': [], 'join': True}, + {'type': 'view', 'schema': []}, + ]) + + +def test_emit_relation_like_truncate_omits_view(): + text = 'TRUNCATE ' + context = _build_suggest_context('truncate', text, None, text, empty_identifier()) + assert sorted_dicts(_emit_relation_like(context)) == sorted_dicts([ + {'type': 'database'}, + {'type': 'table', 'schema': []}, + ]) + + +def test_emit_relation_name_with_schema_parent(): + identifier = SimpleNamespace(get_parent_name=lambda: 'sch') + context = _build_suggest_context('table', '', None, '', identifier) + assert _emit_relation_name(context) == [{'type': 'table', 'schema': 'sch'}] + + +def test_emit_relation_name_without_schema_parent(): + context = _build_suggest_context('view', '', None, '', empty_identifier()) + assert _emit_relation_name(context) == [{'type': 'schema'}, {'type': 'view', 'schema': []}] + + +@pytest.mark.xfail +def test_emit_on_with_parent_filters_tables(): + identifier = SimpleNamespace(get_parent_name=lambda: 'a') + text = 'SELECT * FROM abc a JOIN def d ON a.' + context = _build_suggest_context('on', text, None, text, identifier) + assert sorted_dicts(_emit_on(context)) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', 'a')]}, + # xfail because these currently also are returned + # {'type': 'table', 'schema': 'a'}, + # {'type': 'view', 'schema': 'a'}, + # {'type': 'function', 'schema': 'a'}, + ]) + + +def test_emit_on_without_parent_uses_fk_join_and_aliases(): + text = 'select a.x, b.y from abc a join bcd b on ' + context = _build_suggest_context('on', text, None, text, empty_identifier()) + assert _emit_on(context) == [ + {'type': 'fk_join', 'tables': [(None, 'abc', 'a'), (None, 'bcd', 'b')]}, + {'type': 'alias', 'aliases': ['a', 'b']}, + ] + + +def test_emit_on_without_visible_tables_adds_database_and_table(): + text = 'grant select on ' + context = _build_suggest_context('on', text, None, text, empty_identifier()) + assert _emit_on(context) == [ + {'type': 'fk_join', 'tables': []}, + {'type': 'alias', 'aliases': []}, + {'type': 'database'}, + {'type': 'table', 'schema': []}, + ] + + +def test_emit_database(): + context = _build_suggest_context('database', '', None, '', empty_identifier()) + assert _emit_database(context) == [{'type': 'database'}] + + +def test_emit_where_token_returns_charset_suggestion_when_available(monkeypatch): + text = 'select * from tabl where foo = ' + where_token = next(token for token in sqlparse.parse(text)[0].tokens if isinstance(token, sqlparse.sql.Where)) + context = _build_suggest_context(where_token, text, None, text, empty_identifier()) + suggestion = [{'type': 'character_set'}] + + monkeypatch.setattr(completion_engine, '_charset_suggestion', lambda _tokens: suggestion) + monkeypatch.setattr( + completion_engine, + 'suggest_based_on_last_token', + lambda *_args: pytest.fail('suggest_based_on_last_token should not be called'), + ) + + assert _emit_where_token(context) == suggestion + + +def test_emit_where_token_prepends_enum_value_for_where_fallback(monkeypatch): + text = 'select * from tabl where foo = ' + where_token = next(token for token in sqlparse.parse(text)[0].tokens if isinstance(token, sqlparse.sql.Where)) + context = _build_suggest_context(where_token, text, None, text, empty_identifier()) + prev_keyword = SimpleNamespace(value='where') + enum_suggestion = {'type': 'enum_value'} + fallback = [{'type': 'keyword'}] + + monkeypatch.setattr(completion_engine, '_charset_suggestion', lambda _tokens: None) + monkeypatch.setattr(completion_engine, 'find_prev_keyword', lambda _text: (prev_keyword, 'select * from tabl where ')) + monkeypatch.setattr(completion_engine, '_enum_value_suggestion', lambda _original, _full: enum_suggestion) + monkeypatch.setattr(completion_engine, 'suggest_based_on_last_token', lambda *_args: fallback) + + assert _emit_where_token(context) == [enum_suggestion] + fallback + + +def test_emit_where_token_returns_fallback_for_non_where_keyword(monkeypatch): + text = 'select * from tabl where foo = ' + where_token = next(token for token in sqlparse.parse(text)[0].tokens if isinstance(token, sqlparse.sql.Where)) + context = _build_suggest_context(where_token, text, None, text, empty_identifier()) + fallback = [{'type': 'keyword'}] + + monkeypatch.setattr(completion_engine, '_charset_suggestion', lambda _tokens: None) + monkeypatch.setattr( + completion_engine, + 'find_prev_keyword', + lambda _text: (SimpleNamespace(value='from'), 'select * from tabl '), + ) + monkeypatch.setattr(completion_engine, '_enum_value_suggestion', lambda _original, _full: {'type': 'enum_value'}) + monkeypatch.setattr(completion_engine, 'suggest_based_on_last_token', lambda *_args: fallback) + + assert _emit_where_token(context) == fallback + + +def test_emit_where_token_handles_convert_using_with_trailing_partial_name(monkeypatch): + text = 'select * from tabl where convert(foo using utf' + where_token = next(token for token in sqlparse.parse(text)[0].tokens if isinstance(token, sqlparse.sql.Where)) + context = _build_suggest_context(where_token, text, None, text, empty_identifier()) + + monkeypatch.setattr( + completion_engine, + 'suggest_based_on_last_token', + lambda *_args: pytest.fail('suggest_based_on_last_token should not be called'), + ) + + assert _emit_where_token(context) == [{'type': 'character_set'}] + + +def test_emit_binary_or_comma_prepends_enum_value_for_where_fallback(monkeypatch): + text = 'select * from tabl where foo = ' + context = _build_suggest_context('=', text, None, text, empty_identifier()) + prev_keyword = SimpleNamespace(value='where') + enum_suggestion = {'type': 'enum_value'} + fallback = [{'type': 'column', 'tables': [(None, 'tabl', None)]}] + + monkeypatch.setattr(completion_engine, 'find_prev_keyword', lambda _text: (prev_keyword, 'select * from tabl where ')) + monkeypatch.setattr(completion_engine, '_enum_value_suggestion', lambda _original, _full: enum_suggestion) + monkeypatch.setattr(completion_engine, 'suggest_based_on_last_token', lambda *_args: fallback) + + assert _emit_binary_or_comma(context) == [enum_suggestion] + fallback + + +def test_emit_binary_or_comma_uses_keyword_fallback_for_nonprogressing_rewind(monkeypatch): + text = 'select * from tabl where foo = ' + context = _build_suggest_context(',', text, None, text, empty_identifier()) + prev_keyword = SimpleNamespace(value='where') + fallback = [{'type': 'keyword'}] + + monkeypatch.setattr(completion_engine, 'find_prev_keyword', lambda _text: (prev_keyword, text.rstrip())) + monkeypatch.setattr(completion_engine, '_enum_value_suggestion', lambda _original, _full: None) + monkeypatch.setattr( + completion_engine, + 'suggest_based_on_last_token', + lambda *_args: pytest.fail('suggest_based_on_last_token should not be called'), + ) + monkeypatch.setattr(completion_engine, '_keyword_suggestions', lambda: fallback) + + assert _emit_binary_or_comma(context) == fallback + + +def test_emit_binary_or_comma_returns_rewound_fallback_without_where_enum(monkeypatch): + text = 'select * from tabl and ' + context = _build_suggest_context('and', text, None, text, empty_identifier()) + fallback = [{'type': 'keyword'}] + + monkeypatch.setattr( + completion_engine, + 'find_prev_keyword', + lambda _text: (SimpleNamespace(value='from'), 'select * from '), + ) + monkeypatch.setattr(completion_engine, '_enum_value_suggestion', lambda _original, _full: {'type': 'enum_value'}) + monkeypatch.setattr(completion_engine, 'suggest_based_on_last_token', lambda *_args: fallback) + + assert _emit_binary_or_comma(context) == fallback + + +@pytest.mark.parametrize( + ('token', 'expected'), + [ + (None, False), + (SimpleNamespace(value='where'), True), + (SimpleNamespace(value='HAVING'), True), + (SimpleNamespace(value='from'), False), + (SimpleNamespace(value=''), False), + ], +) +def test_is_where_or_having(token, expected): + assert _is_where_or_having(token) is expected + + +@pytest.mark.parametrize('exc_type', [TypeError, AttributeError]) +def test_suggest_type_returns_keyword_suggestions_when_sqlparse_parse_errors(monkeypatch, exc_type): + monkeypatch.setattr(completion_engine.sqlparse, 'parse', lambda _text: (_ for _ in ()).throw(exc_type())) + + assert suggest_type('select 1', 'select 1') == [{'type': 'keyword'}] + + +@pytest.mark.parametrize('exc_type', [TypeError, AttributeError]) +def test_suggest_type_returns_keyword_suggestions_when_word_parse_errors(monkeypatch, exc_type): + parse_inputs: list[str] = [] + original_parse = sqlparse.parse + + def fake_parse(text: str): + parse_inputs.append(text) + if len(parse_inputs) == 1: + return [original_parse('select ')[0]] + raise exc_type() + + monkeypatch.setattr(completion_engine.sqlparse, 'parse', fake_parse) + + assert suggest_type('select foo', 'select foo') == [{'type': 'keyword'}] + assert parse_inputs == ['select ', 'foo'] + + +def test_suggest_type_dispatches_backslash_commands_to_suggest_special(monkeypatch): + parse_inputs: list[str] = [] + special_inputs: list[str] = [] + original_parse = sqlparse.parse + + def fake_parse(text: str): + parse_inputs.append(text) + return [original_parse('\\dt ')[0]] + + monkeypatch.setattr(completion_engine.sqlparse, 'parse', fake_parse) + monkeypatch.setattr( + completion_engine, + 'suggest_special', + lambda text: special_inputs.append(text) or [{'type': 'special'}], + ) + monkeypatch.setattr( + completion_engine, + 'suggest_based_on_last_token', + lambda *_args: [{'type': 'keyword'}], + ) + + suggestions = suggest_type('\\dt', '\\dt') + + assert parse_inputs == ['\\dt'] + assert special_inputs == ['\\dt'] + assert suggestions == [{'type': 'special'}] + + +@pytest.mark.parametrize( + ('text', 'expected'), + [ + ('\\', [{'type': 'special'}]), + ('use ', [{'type': 'database'}]), + ('connect ', [{'type': 'database'}]), + ('\\u ', [{'type': 'database'}]), + ('\\r ', [{'type': 'database'}]), + ('tableformat ', [{'type': 'table_format'}]), + ('redirectformat ', [{'type': 'table_format'}]), + ('\\T ', [{'type': 'table_format'}]), + ('\\Tr ', [{'type': 'table_format'}]), + ('\\f ', [{'type': 'favoritequery'}]), + ('\\fs ', [{'type': 'favoritequery'}]), + ('\\fd ', [{'type': 'favoritequery'}]), + ('\\dt ', [{'type': 'table', 'schema': []}, {'type': 'view', 'schema': []}, {'type': 'schema'}]), + ('\\dt+ ', [{'type': 'table', 'schema': []}, {'type': 'view', 'schema': []}, {'type': 'schema'}]), + ('\\. ', [{'type': 'file_name'}]), + ('source ', [{'type': 'file_name'}]), + ('\\o ', [{'type': 'file_name'}]), + ('\\once ', [{'type': 'file_name'}]), + ('tee ', [{'type': 'file_name'}]), + ('\\e ', [{'type': 'file_name'}]), + ('\\edit ', [{'type': 'file_name'}]), + ('\\llm ', [{'type': 'llm'}]), + ('\\ai ', [{'type': 'llm'}]), + ('pager ', [{'type': 'keyword'}, {'type': 'special'}]), + ], +) +def test_suggest_special(text, expected): + assert suggest_special(text) == expected + + +@pytest.mark.parametrize( + ('token', 'text_before_cursor', 'word_before_cursor', 'full_text', 'expected'), + [ + (None, '', None, '', [{'type': 'keyword'}]), + ('', '', None, '', [{'type': 'keyword'}, {'type': 'special'}]), + ('*', 'select *', None, 'select *', [{'type': 'keyword'}]), + ('as', 'select 1 as ', None, 'select 1 as ', []), + ('show', 'show ', None, 'show ', [{'type': 'show'}]), + ('to', 'grant all on db.* to ', None, 'grant all on db.* to ', [{'type': 'user'}]), + ('to', 'change master to ', None, 'change master to ', [{'type': 'change'}]), + ('where', 'select * from tabl where ', '9', 'select * from tabl where ', []), + ('where', 'select * from tabl where "fo', '"fo', 'select * from tabl where "fo', []), + ('where', "select * from tabl where 'fo", 'fo', "select * from tabl where 'fo", []), + ], +) +def test_suggest_based_on_last_token(token, text_before_cursor, word_before_cursor, full_text, expected): + suggestion = suggest_based_on_last_token( + token, + text_before_cursor, + word_before_cursor, + full_text, + empty_identifier(), + ) + assert suggestion == expected + + +def test_suggest_based_on_last_token_lparen_in_exists_where_suggests_keyword(): + text = 'SELECT * FROM foo WHERE EXISTS (' + suggestion = suggest_based_on_last_token('(', text, None, text, empty_identifier()) + assert suggestion == [{'type': 'keyword'}] + + +def test_suggest_based_on_last_token_lparen_in_where_any_suggests_columns_functions(): + text = 'SELECT * FROM tabl WHERE foo = ANY(' + suggestion = suggest_based_on_last_token('(', text, None, text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'introducer'}, + ]) + + +def test_suggest_based_on_last_token_lparen_after_join_using_suggests_common_columns(): + text = 'select * from abc inner join def using (' + suggestion = suggest_based_on_last_token('(', text, None, text, empty_identifier()) + assert suggestion == [{'type': 'column', 'tables': [(None, 'abc', None), (None, 'def', None)], 'drop_unique': True}] + + +def test_suggest_based_on_last_token_lparen_after_select_subquery_suggests_keyword(): + text = 'SELECT * FROM (' + suggestion = suggest_based_on_last_token('(', text, None, text, empty_identifier()) + assert suggestion == [{'type': 'keyword'}] + + +def test_suggest_based_on_last_token_lparen_after_show_suggests_show_items(): + text = 'SHOW (' + suggestion = suggest_based_on_last_token('(', text, None, text, empty_identifier()) + assert suggestion == [{'type': 'show'}] + + +def test_suggest_based_on_last_token_lparen_in_function_call_suggests_columns(): + text = 'SELECT MAX(' + full_text = 'SELECT MAX( FROM tbl' + suggestion = suggest_based_on_last_token('(', text, None, full_text, empty_identifier()) + assert suggestion == [{'type': 'column', 'tables': [(None, 'tbl', None)]}] + + +@pytest.mark.parametrize( + ('token', 'text_before_cursor', 'full_text', 'expected'), + [ + ('call', 'call ', 'call ', [{'type': 'procedure', 'schema': []}]), + ('set', 'character set', 'character set', [{'type': 'character_set'}]), + ('distinct', 'select distinct ', 'select distinct ', [{'type': 'column', 'tables': []}]), + ('database', 'drop database ', 'drop database ', [{'type': 'database'}]), + ('template', 'create database foo with template ', 'create database foo with template ', [{'type': 'database'}]), + ('collate', 'collate ', 'collate ', [{'type': 'collation'}]), + ('table', 'drop table ', 'drop table ', [{'type': 'schema'}, {'type': 'table', 'schema': []}]), + ('view', 'drop view ', 'drop view ', [{'type': 'schema'}, {'type': 'view', 'schema': []}]), + ('function', 'drop function ', 'drop function ', [{'type': 'schema'}, {'type': 'function', 'schema': []}]), + ], +) +def test_suggest_based_on_last_token_direct_keyword_branches(token, text_before_cursor, full_text, expected): + suggestion = suggest_based_on_last_token(token, text_before_cursor, None, full_text, empty_identifier()) + assert suggestion == expected + + +def test_suggest_based_on_last_token_relation_keyword_with_schema_parent(): + identifier = SimpleNamespace(get_parent_name=lambda: 'sch') + text = 'INSERT INTO sch.' + suggestion = suggest_based_on_last_token('into', text, None, text, identifier) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'table', 'schema': 'sch'}, + {'type': 'view', 'schema': 'sch'}, + ]) + + +def test_suggest_based_on_last_token_join_keyword_marks_join_suggestions(): + text = 'SELECT * FROM foo JOIN ' + suggestion = suggest_based_on_last_token(last_non_whitespace_token(text), text, None, text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'database'}, + {'type': 'table', 'schema': [], 'join': True}, + {'type': 'view', 'schema': []}, + ]) + + +def test_suggest_based_on_last_token_like_in_create_table_suggests_relations(): + text = 'CREATE TABLE new LIKE ' + suggestion = suggest_based_on_last_token('like', text, None, text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'database'}, + {'type': 'table', 'schema': []}, + {'type': 'view', 'schema': []}, + ]) + + +@pytest.mark.xfail +def test_suggest_based_on_last_token_select_with_parent_identifier_filters_tables(): + identifier = SimpleNamespace(get_parent_name=lambda: 't1') + text = 'SELECT t1.' + full_text = 'SELECT t1. FROM tabl1 t1, tabl2 t2' + suggestion = suggest_based_on_last_token('select', text, None, full_text, identifier) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, + # xfail because these are currently also returned + # {'type': 'table', 'schema': 't1'}, + # {'type': 'view', 'schema': 't1'}, + # {'type': 'function', 'schema': 't1'}, + ]) + + +def test_suggest_based_on_last_token_select_inside_backticks_adds_keywords(): + text = 'SELECT `a' + full_text = 'SELECT `a FROM tabl' + suggestion = suggest_based_on_last_token('select', text, None, full_text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'keyword'}, + ]) + + +def test_suggest_based_on_last_token_on_without_parent_suggests_fk_join_and_aliases(): + text = 'select a.x, b.y from abc a join bcd b on ' + suggestion = suggest_based_on_last_token('on', text, None, text, empty_identifier()) + assert suggestion == [ + {'type': 'fk_join', 'tables': [(None, 'abc', 'a'), (None, 'bcd', 'b')]}, + {'type': 'alias', 'aliases': ['a', 'b']}, + ] + + +def test_suggest_based_on_last_token_on_without_tables_adds_database_and_table(): + text = 'grant select on ' + suggestion = suggest_based_on_last_token('on', text, None, text, empty_identifier()) + assert suggestion == [ + {'type': 'fk_join', 'tables': []}, + {'type': 'alias', 'aliases': []}, + {'type': 'database'}, + {'type': 'table', 'schema': []}, + ] + + +@pytest.mark.xfail +def test_suggest_based_on_last_token_on_with_parent_identifier_filters_tables(): + identifier = SimpleNamespace(get_parent_name=lambda: 'a') + text = 'SELECT * FROM abc a JOIN def d ON a.' + suggestion = suggest_based_on_last_token('on', text, None, text, identifier) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'column', 'tables': [(None, 'abc', 'a')]}, + # xfail because these are currently also returned + # {'type': 'table', 'schema': 'a'}, + # {'type': 'view', 'schema': 'a'}, + # {'type': 'function', 'schema': 'a'}, + ]) + + +def test_suggest_based_on_last_token_binary_operand_in_where_prepends_enum_value(): + text = 'SELECT * FROM tabl WHERE foo = ' + suggestion = suggest_based_on_last_token('=', text, None, text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'enum_value', 'tables': [(None, 'tabl', None)], 'column': 'foo', 'parent': None}, + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'introducer'}, + ]) + + +def test_suggest_based_on_last_token_comma_recurses_to_select_suggestions(): + text = 'SELECT a, ' + full_text = 'SELECT a, FROM tabl' + suggestion = suggest_based_on_last_token(',', text, None, full_text, empty_identifier()) + assert sorted_dicts(suggestion) == sorted_dicts([ + {'type': 'alias', 'aliases': ['tabl']}, + {'type': 'column', 'tables': [(None, 'tabl', None)]}, + {'type': 'function', 'schema': []}, + {'type': 'introducer'}, + ]) + + +def test_suggest_based_on_last_token_nonprogressing_comma_falls_back_to_keyword(): + text = ',' + suggestion = suggest_based_on_last_token(',', text, None, text, empty_identifier()) + assert suggestion == [{'type': 'keyword'}] + + +@pytest.mark.parametrize( + ('identifier', 'schema', 'table', 'alias', 'expected'), + [ + ('t', None, 'tbl', 't', True), + ('tbl', None, 'tbl', 't', True), + ('sch.tbl', 'sch', 'tbl', 't', True), + ('other', 'sch', 'tbl', 't', False), + ('sch.other', 'sch', 'tbl', 't', False), + ('tbl', 'sch', 'other', 't', False), + ], +) +def test_identifies(identifier, schema, table, alias, expected): + assert identifies(identifier, schema, table, alias) is expected + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM tabl WHERE foo IN (", + "SELECT * FROM tabl WHERE foo IN (bar, ", + ], +) +def test_where_in_suggests_columns(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_where_equals_any_suggests_columns_or_keywords(): + text = "SELECT * FROM tabl WHERE foo = ANY(" + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_where_convert_using_suggests_character_set(): + text = 'SELECT * FROM tabl WHERE CONVERT(foo USING ' + suggestions = suggest_type(text, text) + assert suggestions == [{"type": "character_set"}] + + +def test_where_cast_character_set_suggests_character_set(): + text = 'SELECT * FROM tabl WHERE CAST(foo AS CHAR CHARACTER SET ' + suggestions = suggest_type(text, text) + assert suggestions == [{"type": "character_set"}] + + +def test_lparen_suggests_cols(): + suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols1(): + suggestion = suggest_type("SELECT MAX(col1 + FROM tbl", "SELECT MAX(col1 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols2(): + suggestion = suggest_type("SELECT MAX(col1 + col2 + FROM tbl", "SELECT MAX(col1 + col2 + ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols3(): + suggestion = suggest_type("SELECT MAX(col1 || FROM tbl", "SELECT MAX(col1 || ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols4(): + suggestion = suggest_type("SELECT MAX(col1 LIKE FROM tbl", "SELECT MAX(col1 LIKE ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +def test_operand_inside_function_suggests_cols5(): + suggestion = suggest_type("SELECT MAX(col1 DIV FROM tbl", "SELECT MAX(col1 DIV ") + assert suggestion == [{"type": "column", "tables": [(None, "tbl", None)]}] + + +@pytest.mark.xfail +def test_arrow_op_inside_function_suggests_nothing(): + suggestion = suggest_type("SELECT MAX(col1-> FROM tbl", "SELECT MAX(col1->") + assert suggestion == [] + + +def test_select_suggests_cols_and_funcs(): + suggestions = suggest_type("SELECT ", "SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": []}, + {"type": "column", "tables": []}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM ", + "INSERT INTO ", + "COPY ", + "UPDATE ", + "DESCRIBE ", + "DESC ", + "EXPLAIN ", + ], +) +def test_expression_suggests_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "database"}, + ]) + + +def test_join_expression_suggests_tables_views_and_schemas(): + expression = "SELECT * FROM foo JOIN " + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": [], "join": True}, + {"type": "view", "schema": []}, + {"type": "database"}, + ]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM sch.", + "INSERT INTO sch.", + "COPY sch.", + "UPDATE sch.", + "DESCRIBE sch.", + "DESC sch.", + "EXPLAIN sch.", + ], +) +def test_expression_suggests_qualified_tables_views_and_schemas(expression): + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": "sch"}, + {"type": "view", "schema": "sch"}, + ]) + + +def test_join_expression_suggests_qualified_tables_views_and_schemas(): + expression = "SELECT * FROM foo JOIN sch." + suggestions = suggest_type(expression, expression) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": "sch", "join": True}, + {"type": "view", "schema": "sch"}, + ]) + + +def test_truncate_suggests_tables_and_schemas(): + suggestions = suggest_type("TRUNCATE ", "TRUNCATE ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": []}, + {"type": "database"}, + ]) + + +def test_truncate_suggests_qualified_tables(): + suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": "sch"}, + ]) + + +def test_distinct_suggests_cols(): + suggestions = suggest_type("SELECT DISTINCT ", "SELECT DISTINCT ") + assert suggestions == [{"type": "column", "tables": []}] + + +def test_col_comma_suggests_cols(): + suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tbl"]}, + {"type": "column", "tables": [(None, "tbl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_table_comma_suggests_tables_and_schemas(): + suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + ]) + + +def test_into_suggests_tables_and_schemas(): + suggestion = suggest_type("INSERT INTO ", "INSERT INTO ") + assert sorted_dicts(suggestion) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + ]) + + +def test_insert_into_lparen_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (", "INSERT INTO abc (") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_insert_into_lparen_partial_text_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (i", "INSERT INTO abc (i") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_insert_into_lparen_comma_suggests_cols(): + suggestions = suggest_type("INSERT INTO abc (id,", "INSERT INTO abc (id,") + assert suggestions == [{"type": "column", "tables": [(None, "abc", None)]}] + + +def test_partially_typed_col_name_suggests_col_names(): + suggestions = suggest_type("SELECT * FROM tabl WHERE col_n", "SELECT * FROM tabl WHERE col_n") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["tabl"]}, + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): + suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "tabl", None)]}, + {"type": "table", "schema": "tabl"}, + {"type": "view", "schema": "tabl"}, + {"type": "function", "schema": "tabl"}, + ]) + + +def test_dot_suggests_cols_of_an_alias(): + suggestions = suggest_type("SELECT t1. FROM tabl1 t1, tabl2 t2", "SELECT t1.") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": "t1"}, + {"type": "view", "schema": "t1"}, + {"type": "column", "tables": [(None, "tabl1", "t1")]}, + {"type": "function", "schema": "t1"}, + ]) + + +def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): + suggestions = suggest_type("SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2.") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "tabl2", "t2")]}, + {"type": "table", "schema": "t2"}, + {"type": "view", "schema": "t2"}, + {"type": "function", "schema": "t2"}, + ]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (", + "SELECT * FROM foo WHERE EXISTS (", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (", + "SELECT 1 AS", + ], +) +def test_sub_select_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{"type": "keyword"}] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (S", + "SELECT * FROM foo WHERE EXISTS (S", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (S", + ], +) +def test_sub_select_partial_text_suggests_keyword(expression): + suggestion = suggest_type(expression, expression) + assert suggestion == [{"type": "keyword"}] + + +def test_outer_table_reference_in_exists_subquery_suggests_columns(): + q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f." + suggestions = suggest_type(q, q) + assert suggestions == [ + {"type": "column", "tables": [(None, "foo", "f")]}, + {"type": "table", "schema": "f"}, + {"type": "view", "schema": "f"}, + {"type": "function", "schema": "f"}, + ] + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT * FROM (SELECT * FROM ", + "SELECT * FROM foo WHERE EXISTS (SELECT * FROM ", + "SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ", + ], +) +def test_sub_select_table_name_completion(expression): + suggestion = suggest_type(expression, expression) + assert sorted_dicts(suggestion) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + ]) + + +def test_sub_select_col_name_completion(): + suggestions = suggest_type("SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["abc"]}, + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +@pytest.mark.xfail +def test_sub_select_multiple_col_name_completion(): + suggestions = suggest_type("SELECT * FROM (SELECT a, FROM abc", "SELECT * FROM (SELECT a, ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "abc", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_sub_select_dot_col_name_completion(): + suggestions = suggest_type("SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t.") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "tabl", "t")]}, + {"type": "table", "schema": "t"}, + {"type": "view", "schema": "t"}, + {"type": "function", "schema": "t"}, + ]) + + +@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"]) +@pytest.mark.parametrize("tbl_alias", ["", "foo"]) +def test_join_suggests_tables_and_schemas(tbl_alias, join_type): + text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN " + suggestion = suggest_type(text, text) + assert sorted_dicts(suggestion) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": [], "join": True}, + {"type": "view", "schema": []}, + ]) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.", + ], +) +def test_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "abc", "a")]}, + {"type": "table", "schema": "a"}, + {"type": "view", "schema": "a"}, + {"type": "function", "schema": "a"}, + ]) + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.", + ], +) +def test_join_alias_dot_suggests_cols2(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "def", "d")]}, + {"type": "table", "schema": "d"}, + {"type": "view", "schema": "d"}, + {"type": "function", "schema": "d"}, + ]) + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on ", + "select a.x, b.y from abc a join bcd b on a.id = b.id OR ", + "select a.x, b.y from abc a join bcd b on a.id = b.id + ", + "select a.x, b.y from abc a join bcd b on a.id = b.id < ", + ], +) +def test_on_suggests_aliases(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]}, + {"type": "alias", "aliases": ["a", "b"]}, + ] + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ", + ], +) +def test_on_suggests_tables(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]}, + {"type": "alias", "aliases": ["abc", "bcd"]}, + ] + + +@pytest.mark.parametrize( + "sql", + [ + "select a.x, b.y from abc a join bcd b on a.id = ", + "select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ", + ], +) +def test_on_suggests_aliases_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", "a"), (None, "bcd", "b")]}, + {"type": "alias", "aliases": ["a", "b"]}, + ] + + +@pytest.mark.parametrize( + "sql", + [ + "select abc.x, bcd.y from abc join bcd on ", + "select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ", + ], +) +def test_on_suggests_tables_right_side(sql): + suggestions = suggest_type(sql, sql) + assert suggestions == [ + {"type": "fk_join", "tables": [(None, "abc", None), (None, "bcd", None)]}, + {"type": "alias", "aliases": ["abc", "bcd"]}, + ] + + +@pytest.mark.parametrize("col_list", ["", "col1, "]) +def test_join_using_suggests_common_columns(col_list): + text = "select * from abc inner join def using (" + col_list + assert suggest_type(text, text) == [{"type": "column", "tables": [(None, "abc", None), (None, "def", None)], "drop_unique": True}] + + +@pytest.mark.parametrize( + "sql", + [ + "SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.", + "SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.", + ], +) +def test_two_join_alias_dot_suggests_cols1(sql): + suggestions = suggest_type(sql, sql) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "column", "tables": [(None, "ghi", "g")]}, + {"type": "table", "schema": "g"}, + {"type": "view", "schema": "g"}, + {"type": "function", "schema": "g"}, + ]) + + +def test_2_statements_2nd_current(): + suggestions = suggest_type("select * from a; select * from ", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "database"}, + ]) + + suggestions = suggest_type("select * from a; select from b", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + # Should work even if first statement is invalid + suggestions = suggest_type("select * from; select * from ", "select * from; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + {"type": "database"}, + ]) + + +def test_2_statements_1st_current(): + suggestions = suggest_type("select * from ; select * from b", "select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + ]) + + suggestions = suggest_type("select from a; select * from b", "select ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["a"]}, + {"type": "column", "tables": [(None, "a", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_3_statements_2nd_current(): + suggestions = suggest_type("select * from a; select * from ; select * from c", "select * from a; select * from ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": []}, + {"type": "view", "schema": []}, + ]) + + suggestions = suggest_type("select * from a; select from b; select * from c", "select * from a; select ") + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "alias", "aliases": ["b"]}, + {"type": "column", "tables": [(None, "b", None)]}, + {"type": "function", "schema": []}, + {"type": "introducer"}, + ]) + + +def test_create_db_with_template(): + suggestions = suggest_type("create database foo with template ", "create database foo with template ") + + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) + + +@pytest.mark.parametrize("initial_text", ["", " ", "\t \t"]) +def test_specials_included_for_initial_completion(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}, {"type": "special"}]) + + +@pytest.mark.parametrize('initial_text', ['REDIRECT']) +def test_specials_included_with_caps(initial_text): + suggestions = suggest_type(initial_text, initial_text) + + assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) + + +def test_specials_not_included_after_initial_token(): + suggestions = suggest_type("create table foo (dt d", "create table foo (dt d") + + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "keyword"}]) + + +def test_drop_schema_qualified_table_suggests_only_tables(): + text = "DROP TABLE schema_name.table_name" + suggestions = suggest_type(text, text) + assert suggestions == [{"type": "table", "schema": "schema_name"}] + + +@pytest.mark.parametrize("text", [",", " ,", "sel ,"]) +def test_handle_pre_completion_comma_gracefully(text): + suggestions = suggest_type(text, text) + + assert iter(suggestions) + + +def test_cross_join(): + text = "select * from v1 cross join v2 JOIN v1.id, " + suggestions = suggest_type(text, text) + assert sorted_dicts(suggestions) == sorted_dicts([ + {"type": "database"}, + {"type": "table", "schema": [], "join": True}, + {"type": "view", "schema": []}, + ]) + + +@pytest.mark.parametrize( + "expression", + [ + "SELECT 1 AS ", + "SELECT 1 FROM tabl AS ", + ], +) +def test_after_as(expression): + suggestions = suggest_type(expression, expression) + assert set(suggestions) == set() + + +@pytest.mark.parametrize( + "expression", + [ + "\\. ", + "select 1; \\. ", + "select 1;\\. ", + "select 1 ; \\. ", + "source ", + "truncate table test; source ", + "truncate table test ; source ", + "truncate table test;source ", + ], +) +def test_source_is_file(expression): + # "source" has to be registered by hand because that usually happens inside MyCLI in mycli/main.py + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) + suggestions = suggest_type(expression, expression) + assert suggestions == [{"type": "file_name"}] + + +@pytest.mark.parametrize( + "expression", + [ + "\\f ", + ], +) +def test_favorite_name_suggestion(expression): + suggestions = suggest_type(expression, expression) + assert suggestions == [{"type": "favoritequery"}] + + +def test_order_by(): + text = "select * from foo order by " + suggestions = suggest_type(text, text) + assert suggestions == [{"tables": [(None, "foo", None)], "type": "column"}] + + +def test_quoted_where(): + text = "'where i=';" + suggestions = suggest_type(text, text) + assert suggestions == [{"type": "keyword"}] + + +def test_find_doubled_backticks_none(): + text = 'select `ab`' + assert _find_doubled_backticks(text) == [] + + +def test_find_doubled_backticks_some(): + text = 'select `a``b`' + assert _find_doubled_backticks(text) == [9, 10] + + +def test_inside_quotes_01(): + text = "select '" + assert is_inside_quotes(text, len(text)) == 'single' + + +def test_inside_quotes_02(): + text = "select '\\'" + assert is_inside_quotes(text, len(text)) == 'single' + + +def test_inside_quotes_03(): + text = "select '`" + assert is_inside_quotes(text, len(text)) == 'single' + + +def test_inside_quotes_04(): + text = 'select "' + assert is_inside_quotes(text, len(text)) == 'double' + + +def test_inside_quotes_05(): + text = 'select "\\"\'' + assert is_inside_quotes(text, len(text)) == 'double' + + +def test_inside_quotes_06(): + text = 'select ""' + assert is_inside_quotes(text, len(text)) is False + + +@pytest.mark.parametrize( + ["text", "position", "expected"], + [ + ("select `'", len("select `'"), 'backtick'), + ("select `' ", len("select `' "), 'backtick'), + ("select `'", -1, 'backtick'), + ("select `'", -2, False), + ('select `ab` ', -1, False), + ('select `ab` ', -2, 'backtick'), + ('select `a``b` ', -1, False), + ('select `a``b` ', -2, 'backtick'), + ('select `a``b` ', -3, 'backtick'), + ('select `a``b` ', -4, 'backtick'), + ('select `a``b` ', -5, 'backtick'), + ('select `a``b` ', -6, 'backtick'), + ('select `a``b` ', -7, False), + ] +) # fmt: skip +def test_inside_quotes_backtick_01(text, position, expected): + assert is_inside_quotes(text, position) == expected + + +def test_inside_quotes_backtick_02(): + """Empty backtick pairs are treated as a doubled (escaped) backtick. + This is okay because it is invalid SQL, and we don't have to complete on it. + """ + text = 'select ``' + assert is_inside_quotes(text, -1) is False + + +def test_inside_quotes_backtick_03(): + """Empty backtick pairs are treated as a doubled (escaped) backtick. + This is okay because it is invalid SQL, and we don't have to complete on it. + """ + text = 'select ``' + assert is_inside_quotes(text, -2) is False diff --git a/test/pytests/test_completion_refresher.py b/test/pytests/test_completion_refresher.py new file mode 100644 index 00000000..afb9f252 --- /dev/null +++ b/test/pytests/test_completion_refresher.py @@ -0,0 +1,440 @@ +# type: ignore + +import time +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +import mycli.completion_refresher as completion_refresher + + +@pytest.fixture +def refresher(): + return completion_refresher.CompletionRefresher() + + +class FakeThread: + def __init__(self, target, args, name) -> None: + self.target = target + self.args = args + self.name = name + self.daemon = False + self.started = False + self.alive = False + + def start(self) -> None: + self.started = True + self.alive = True + + def run_target(self) -> None: + try: + self.target(*self.args) + finally: + self.alive = False + + def is_alive(self) -> bool: + return self.alive + + +def make_sqlexecute() -> SimpleNamespace: + return SimpleNamespace( + dbname='db', + user='user', + password='pw', + host='host', + port=3306, + socket='/tmp/mysql.sock', + character_set='utf8mb4', + local_infile=False, + ssl={'ca': 'ca.pem'}, + ssh_user='ssh-user', + ssh_host='ssh-host', + ssh_port=22, + ssh_password='ssh-pw', + ssh_key_filename='id_rsa', + ) + + +def test_ctor(refresher) -> None: + assert len(refresher.refreshers) > 0 + assert list(refresher.refreshers.keys()) == [ + "databases", + "schemata", + "tables", + "foreign_keys", + "enum_values", + "users", + "functions", + "procedures", + 'character_sets', + 'collations', + "special_commands", + "show_commands", + "keywords", + ] + + +def test_refresh_called_once(refresher): + """ + + :param refresher: + :return: + """ + callbacks = Mock() + sqlexecute = Mock() + + with patch.object(refresher, "_bg_refresh") as bg_refresh: + actual = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert actual[0].preamble is None + assert actual[0].header is None + assert actual[0].rows is None + assert actual[0].status == "Auto-completion refresh started in the background." + bg_refresh.assert_called_with(sqlexecute, callbacks, {}) + + +def test_refresh_called_twice(refresher): + """If refresh is called a second time, it should be restarted. + + :param refresher: + :return: + + """ + callbacks = Mock() + + sqlexecute = Mock() + + def dummy_bg_refresh(*args): + time.sleep(3) # seconds + + refresher._bg_refresh = dummy_bg_refresh + + actual1 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert actual1[0].preamble is None + assert actual1[0].header is None + assert actual1[0].rows is None + assert actual1[0].status == "Auto-completion refresh started in the background." + + actual2 = refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert actual2[0].preamble is None + assert actual2[0].header is None + assert actual2[0].rows is None + assert actual2[0].status == "Auto-completion refresh restarted." + assert refresher._completer_thread is not None + refresher._completer_thread.join() + + +def test_refresh_with_callbacks(refresher): + """Callbacks must be called. + + :param refresher: + + """ + callbacks = [Mock()] + sqlexecute_class = Mock() + sqlexecute = Mock() + + with patch("mycli.completion_refresher.SQLExecute", sqlexecute_class): + # Set refreshers to 0: we're not testing refresh logic here + refresher.refreshers = {} + refresher.refresh(sqlexecute, callbacks) + time.sleep(1) # Wait for the thread to work. + assert callbacks[0].call_count == 1 + + +def test_refresh_starts_background_thread(monkeypatch, refresher) -> None: + calls: list[tuple[object, object, dict]] = [] + + def fake_bg_refresh(executor, callbacks, options) -> None: + calls.append((executor, callbacks, options)) + + monkeypatch.setattr(completion_refresher.threading, 'Thread', FakeThread) + monkeypatch.setattr(refresher, '_bg_refresh', fake_bg_refresh) + + sqlexecute = Mock() + callbacks = Mock() + + actual = refresher.refresh(sqlexecute, callbacks) + + assert actual[0].status == "Auto-completion refresh started in the background." + assert refresher._completer_thread is not None + assert refresher._completer_thread.name == "completion_refresh" + assert refresher._completer_thread.daemon is True + assert refresher._completer_thread.started is True + assert refresher.is_refreshing() is True + assert calls == [] + + refresher._completer_thread.run_target() + assert calls == [(sqlexecute, callbacks, {})] + assert refresher.is_refreshing() is False + + +def test_refresh_passes_explicit_completer_options(monkeypatch, refresher) -> None: + calls: list[tuple[object, object, dict]] = [] + + def fake_bg_refresh(executor, callbacks, options) -> None: + calls.append((executor, callbacks, options)) + + monkeypatch.setattr(completion_refresher.threading, 'Thread', FakeThread) + monkeypatch.setattr(refresher, '_bg_refresh', fake_bg_refresh) + + sqlexecute = Mock() + callbacks = Mock() + options = {'smart_completion': True} + + refresher.refresh(sqlexecute, callbacks, options) + refresher._completer_thread.run_target() + + assert calls == [(sqlexecute, callbacks, options)] + + +def test_refresh_while_refreshing_restarts(monkeypatch, refresher) -> None: + thread_calls: list[tuple[object, object, str]] = [] + + def fail_thread(*, target, args, name): + thread_calls.append((target, args, name)) + return FakeThread(target, args, name) + + monkeypatch.setattr(completion_refresher.threading, 'Thread', fail_thread) + existing_thread = SimpleNamespace(is_alive=lambda: True) + refresher._completer_thread = existing_thread + + actual = refresher.refresh(Mock(), Mock()) + + assert actual[0].status == "Auto-completion refresh restarted." + assert refresher._restart_refresh.is_set() is True + assert refresher._completer_thread is existing_thread + assert thread_calls == [] + + +def test_bg_refresh_restarts_wraps_callbacks_and_closes(monkeypatch, refresher) -> None: + completers: list[SimpleNamespace] = [] + executor_inits: list[tuple[object, ...]] = [] + executors: list[object] = [] + refresher_calls: list[str] = [] + callback_calls: list[tuple[str, SimpleNamespace]] = [] + event_order: list[str] = [] + + class FakeCompleter: + tidb_functions = ['tidb-func'] + tidb_keywords = ['tidb-keyword'] + + def __init__(self, **options) -> None: + self.options = options + completers.append(self) + + class FakeExecutor: + def __init__(self, *args) -> None: + executor_inits.append(args) + self.closed = False + executors.append(self) + + def close(self) -> None: + self.closed = True + event_order.append('close') + + def first_refresher(completer, executor) -> None: + refresher_calls.append('first') + event_order.append('refresher:first') + if refresher_calls == ['first']: + refresher._restart_refresh.set() + + def second_refresher(completer, executor) -> None: + refresher_calls.append('second') + event_order.append('refresher:second') + + def first_callback(completer) -> None: + callback_calls.append(('first', completer)) + event_order.append('callback:first') + + def second_callback(completer) -> None: + callback_calls.append(('second', completer)) + event_order.append('callback:second') + + monkeypatch.setattr(completion_refresher, 'SQLCompleter', FakeCompleter) + monkeypatch.setattr(completion_refresher, 'SQLExecute', FakeExecutor) + refresher.refreshers = { + 'first': first_refresher, + 'second': second_refresher, + } + + sqlexecute = make_sqlexecute() + refresher._bg_refresh(sqlexecute, [first_callback, second_callback], {'smart_completion': True}) + + assert len(completers) == 1 + assert completers[0].options == {'smart_completion': True} + assert executor_inits == [ + ( + 'db', + 'user', + 'pw', + 'host', + 3306, + '/tmp/mysql.sock', + 'utf8mb4', + False, + {'ca': 'ca.pem'}, + 'ssh-user', + 'ssh-host', + 22, + 'ssh-pw', + 'id_rsa', + ) + ] + assert len(executors) == 1 + assert executors[0].closed is True + assert refresher_calls == ['first', 'first', 'second'] + assert refresher._restart_refresh.is_set() is False + assert callback_calls == [('first', completers[0]), ('second', completers[0])] + assert event_order == [ + 'refresher:first', + 'refresher:first', + 'refresher:second', + 'callback:first', + 'callback:second', + 'close', + ] + + +def test_bg_refresh_wraps_single_callback_callable(monkeypatch, refresher) -> None: + completers: list[SimpleNamespace] = [] + + class FakeCompleter: + tidb_functions = [] + tidb_keywords = [] + + def __init__(self, **options) -> None: + completers.append(self) + + class FakeExecutor: + def __init__(self, *args) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + callback = Mock() + + monkeypatch.setattr(completion_refresher, 'SQLCompleter', FakeCompleter) + monkeypatch.setattr(completion_refresher, 'SQLExecute', FakeExecutor) + refresher.refreshers = {} + + refresher._bg_refresh(make_sqlexecute(), callback, {}) + + callback.assert_called_once_with(completers[0]) + + +def test_refresher_decorator_registers_function() -> None: + refreshers: dict[str, object] = {} + + @completion_refresher.refresher('demo', refreshers=refreshers) + def demo(completer, executor) -> None: + return None + + assert refreshers == {'demo': demo} + + +def test_refresh_helpers_delegate_to_completer_and_executor(monkeypatch) -> None: + completer = Mock() + executor = Mock() + executor.dbname = 'current_db' + executor.databases.return_value = ['db1', 'db2'] + executor.table_columns.return_value = iter([('tbl', 'col')]) + executor.foreign_keys.return_value = iter([('tbl', 'col', 'other', 'id')]) + executor.enum_values.return_value = iter([('tbl', 'status', ['open'])]) + executor.users.return_value = iter([('app',)]) + executor.procedures.return_value = iter([('proc',)]) + executor.character_sets.return_value = iter([('utf8mb4',)]) + executor.collations.return_value = iter([('utf8mb4_unicode_ci',)]) + executor.show_candidates.return_value = iter([('FULL TABLES',)]) + + monkeypatch.setattr(completion_refresher, 'COMMANDS', {'\\x': object(), 'help': object()}) + + completion_refresher.refresh_databases(completer, executor) + completion_refresher.refresh_schemata(completer, executor) + completion_refresher.refresh_tables(completer, executor) + completion_refresher.refresh_foreign_keys(completer, executor) + completion_refresher.refresh_enum_values(completer, executor) + completion_refresher.refresh_users(completer, executor) + completion_refresher.refresh_procedures(completer, executor) + completion_refresher.refresh_character_sets(completer, executor) + completion_refresher.refresh_collations(completer, executor) + completion_refresher.refresh_special(completer, executor) + completion_refresher.refresh_show_commands(completer, executor) + + completer.extend_database_names.assert_called_once_with(['db1', 'db2']) + completer.extend_schemata.assert_called_once_with('current_db') + completer.set_dbname.assert_called_once_with('current_db') + completer.extend_relations.assert_called_once_with([('tbl', 'col')], kind='tables') + completer.extend_columns.assert_called_once_with([('tbl', 'col')], kind='tables') + completer.extend_foreign_keys.assert_called_once_with(executor.foreign_keys.return_value) + completer.extend_enum_values.assert_called_once_with(executor.enum_values.return_value) + completer.extend_users.assert_called_once_with(executor.users.return_value) + completer.extend_procedures.assert_called_once_with(executor.procedures.return_value) + completer.extend_character_sets.assert_called_once_with(executor.character_sets.return_value) + completer.extend_collations.assert_called_once_with(executor.collations.return_value) + completer.extend_special_commands.assert_called_once_with(['\\x', 'help']) + completer.extend_show_items.assert_called_once_with(executor.show_candidates.return_value) + + +def test_refresh_functions_extends_tidb_builtins_only_for_tidb() -> None: + completer = Mock() + completer.tidb_functions = ['tidb_func'] + + executor = Mock() + executor.functions.return_value = iter([('func',)]) + executor.server_info = SimpleNamespace(species=completion_refresher.ServerSpecies.TiDB) + + completion_refresher.refresh_functions(completer, executor) + + assert completer.extend_functions.call_args_list == [ + ((executor.functions.return_value,), {}), + ((['tidb_func'],), {'builtin': True}), + ] + + completer.reset_mock() + executor.server_info = SimpleNamespace(species=completion_refresher.ServerSpecies.MySQL) + + completion_refresher.refresh_functions(completer, executor) + + assert completer.extend_functions.call_args_list == [ + ((executor.functions.return_value,), {}), + ] + + completer.reset_mock() + executor.server_info = None + + completion_refresher.refresh_functions(completer, executor) + + assert completer.extend_functions.call_args_list == [ + ((executor.functions.return_value,), {}), + ] + + +def test_refresh_keywords_extends_tidb_keywords_only_for_tidb() -> None: + completer = Mock() + completer.tidb_keywords = ['FLASHBACK'] + + executor = Mock() + executor.server_info = SimpleNamespace(species=completion_refresher.ServerSpecies.TiDB) + + completion_refresher.refresh_keywords(completer, executor) + + completer.extend_keywords.assert_called_once_with(['FLASHBACK'], replace=True) + + completer.reset_mock() + executor.server_info = SimpleNamespace(species=completion_refresher.ServerSpecies.MySQL) + + completion_refresher.refresh_keywords(completer, executor) + + completer.extend_keywords.assert_not_called() + + completer.reset_mock() + executor.server_info = None + + completion_refresher.refresh_keywords(completer, executor) + + completer.extend_keywords.assert_not_called() diff --git a/test/pytests/test_config.py b/test/pytests/test_config.py new file mode 100644 index 00000000..45b26ec4 --- /dev/null +++ b/test/pytests/test_config.py @@ -0,0 +1,373 @@ +# type: ignore + +"""Unit tests for the mycli.config module.""" + +import builtins +from io import BytesIO, StringIO, TextIOWrapper +import logging +import os +import struct +import sys +from tempfile import NamedTemporaryFile +from types import SimpleNamespace + +from configobj import ConfigObj +import pytest + +from mycli import config as config_module +from mycli.config import ( + _remove_pad, + create_default_config, + encrypt_mylogin_cnf, + get_included_configs, + get_mylogin_cnf_path, + log, + open_mylogin_cnf, + read_and_decrypt_mylogin_cnf, + read_config_file, + read_config_files, + str_to_bool, + strip_matching_quotes, + write_default_config, +) +from test.utils import TEMPFILE_PREFIX + +LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), "../mylogin.cnf")) + + +def open_bmylogin_cnf(name): + """Open contents of *name* in a BytesIO buffer.""" + with open(name, "rb") as f: + buf = BytesIO() + buf.write(f.read()) + return buf + + +def test_read_mylogin_cnf(): + """Tests that a login path file can be read and decrypted.""" + mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) + + assert isinstance(mylogin_cnf, TextIOWrapper) + + contents = mylogin_cnf.read() + for word in ("[test]", "user", "password", "host", "port"): + assert word in contents + + +def test_decrypt_blank_mylogin_cnf(): + """Test that a blank login path file is handled correctly.""" + mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO()) + assert mylogin_cnf is None + + +def test_corrupted_login_key(): + """Test that a corrupted login path key is handled correctly.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the unused bytes + buf.seek(4) + + # Write null bytes over half the login key + buf.write(b"\0\0\0\0\0\0\0\0\0\0") + + buf.seek(0) + mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) + + assert mylogin_cnf is None + + +def test_corrupted_pad(): + """Tests that a login path file with a corrupted pad is partially read.""" + buf = open_bmylogin_cnf(LOGIN_PATH_FILE) + + # Skip past the login key + buf.seek(24) + + # Skip option group + len_buf = buf.read(4) + (cipher_len,) = struct.unpack(" None: + fake_logger = SimpleNamespace(parent=SimpleNamespace(name='root'), log=lambda level, message: None) + + log(fake_logger, logging.WARNING, 'root warning') + + assert capsys.readouterr().err == 'root warning\n' + + +def test_read_config_file_from_path_and_parse_error(tmp_path, caplog) -> None: + valid_path = tmp_path / 'valid.cnf' + valid_path.write_text('[main]\ncolor = blue\n', encoding='utf8') + + config = read_config_file(str(valid_path)) + assert config['main']['color'] == 'blue' + + invalid_path = tmp_path / 'invalid.cnf' + invalid_path.write_text('[main\nfoo=bar\n', encoding='utf8') + + with caplog.at_level(logging.WARNING, logger='mycli.config'): + parsed = read_config_file(str(invalid_path)) + assert parsed['foo'] == 'bar' + assert "Unable to parse line 1 of config file" in caplog.text + assert 'Using successfully parsed config values.' in caplog.text + + +def test_read_config_file_permission_error(monkeypatch, caplog) -> None: + def raise_oserror(*_args, **_kwargs): + raise OSError(13, 'denied', '/tmp/test.cnf') + + monkeypatch.setattr(config_module, 'ConfigObj', raise_oserror) + + with caplog.at_level(logging.WARNING, logger='mycli.config'): + assert read_config_file('/tmp/test.cnf') is None + assert "You don't have permission to read config file '/tmp/test.cnf'." in caplog.text + + +def test_get_included_configs_handles_paths_and_errors(tmp_path, monkeypatch) -> None: + include_dir = tmp_path / 'includes' + include_dir.mkdir() + expected = include_dir / 'included.cnf' + expected.write_text('[main]\nfoo = bar\n', encoding='utf8') + (include_dir / 'ignore.txt').write_text('skip', encoding='utf8') + + config_path = tmp_path / 'root.cnf' + config_path.write_text(f'!includedir {include_dir}\n', encoding='utf8') + + assert get_included_configs(BytesIO()) == [] + assert get_included_configs(str(tmp_path / 'missing.cnf')) == [] + assert get_included_configs(str(config_path)) == [str(expected)] + + monkeypatch.setattr(builtins, 'open', lambda *_args, **_kwargs: (_ for _ in ()).throw(PermissionError())) + assert get_included_configs(str(config_path)) == [] + + +def test_read_config_files_merges_includes_and_honors_flags(monkeypatch) -> None: + first_config = ConfigObj({'main': {'color': 'blue'}}) + first_config.filename = 'first.cnf' + included_config = ConfigObj({'main': {'pager': 'less'}}) + included_config.filename = 'included.cnf' + + monkeypatch.setattr(config_module, 'create_default_config', lambda list_values=True: ConfigObj({'default': {'a': '1'}})) + + def fake_read_config_file(filename, list_values=True): + if filename == 'first.cnf': + return first_config + if filename == 'included.cnf': + return included_config + return None + + monkeypatch.setattr(config_module, 'read_config_file', fake_read_config_file) + monkeypatch.setattr(config_module, 'get_included_configs', lambda filename: ['included.cnf'] if filename == 'first.cnf' else []) + + merged = read_config_files(['first.cnf']) + assert merged['default']['a'] == '1' + assert merged['main']['color'] == 'blue' + assert merged['main']['pager'] == 'less' + assert merged.filename == 'included.cnf' + + ignored_defaults = read_config_files(['first.cnf'], ignore_package_defaults=True) + assert 'default' not in ignored_defaults + assert ignored_defaults['main']['color'] == 'blue' + + untouched = read_config_files(['first.cnf'], ignore_user_options=True) + assert untouched == ConfigObj({'default': {'a': '1'}}) + assert 'main' not in untouched + + +def test_create_and_write_default_config(tmp_path) -> None: + default_config = create_default_config() + assert 'main' in default_config + + destination = tmp_path / 'myclirc' + write_default_config(str(destination)) + written = destination.read_text(encoding='utf8') + assert '[main]' in written + + destination.write_text('custom', encoding='utf8') + write_default_config(str(destination)) + assert destination.read_text(encoding='utf8') == 'custom' + + write_default_config(str(destination), overwrite=True) + assert '[main]' in destination.read_text(encoding='utf8') + + +def test_get_mylogin_cnf_path_returns_none_for_missing_file(monkeypatch, tmp_path) -> None: + monkeypatch.setenv('MYSQL_TEST_LOGIN_FILE', str(tmp_path / 'missing.mylogin.cnf')) + + assert get_mylogin_cnf_path() is None + + +def test_open_mylogin_cnf_error_paths(monkeypatch, tmp_path, caplog) -> None: + with caplog.at_level(logging.ERROR, logger='mycli.config'): + assert open_mylogin_cnf(str(tmp_path / 'missing.mylogin.cnf')) is None + assert 'Unable to open login path file.' in caplog.text + + caplog.clear() + existing = tmp_path / 'present.mylogin.cnf' + existing.write_bytes(b'not-used') + monkeypatch.setattr(config_module, 'read_and_decrypt_mylogin_cnf', lambda f: None) + + with caplog.at_level(logging.ERROR, logger='mycli.config'): + assert open_mylogin_cnf(str(existing)) is None + assert 'Unable to read login path file.' in caplog.text + + +def test_encrypt_mylogin_cnf_round_trip() -> None: + plaintext = StringIO('[client]\nuser=test\npassword=secret\n') + + encrypted = encrypt_mylogin_cnf(plaintext) + decrypted = read_and_decrypt_mylogin_cnf(encrypted) + + assert isinstance(encrypted, BytesIO) + assert decrypted.read().decode('utf8') == '[client]\nuser=test\npassword=secret\n' + + +def test_read_and_decrypt_mylogin_cnf_error_branches(caplog) -> None: + incomplete_key = BytesIO(struct.pack('i', 0) + b'a') + with caplog.at_level(logging.ERROR, logger='mycli.config'): + assert read_and_decrypt_mylogin_cnf(incomplete_key) is None + assert 'Unable to generate login path AES key.' in caplog.text + + caplog.clear() + no_payload = BytesIO(struct.pack('i', 0) + b'0123456789abcdefghij') + with caplog.at_level(logging.ERROR, logger='mycli.config'): + assert read_and_decrypt_mylogin_cnf(no_payload) is None + assert 'No data successfully decrypted from login path file.' in caplog.text + + +def test_remove_pad_valid_and_invalid_cases(caplog) -> None: + assert _remove_pad(b'hello\x03\x03\x03') == b'hello' + + with caplog.at_level(logging.WARNING, logger='mycli.config'): + assert _remove_pad(b'') is False + assert 'Unable to remove pad.' in caplog.text + + caplog.clear() + with caplog.at_level(logging.WARNING, logger='mycli.config'): + assert _remove_pad(b'hello\x02\x03') is False + assert 'Invalid pad found in login path file.' in caplog.text + + +def test_strip_quotes_with_matching_quotes(): + """Test that a string with matching quotes is unquoted.""" + + s = "May the force be with you." + assert s == strip_matching_quotes(f'"{s}"') + assert s == strip_matching_quotes(f"'{s}'") + + +def test_strip_quotes_with_unmatching_quotes(): + """Test that a string with unmatching quotes is not unquoted.""" + + s = "May the force be with you." + assert '"' + s == strip_matching_quotes(f'"{s}') + assert s + "'" == strip_matching_quotes(f"{s}'") + + +def test_strip_quotes_with_empty_string(): + """Test that an empty string is handled during unquoting.""" + + assert "" == strip_matching_quotes("") + + +def test_strip_quotes_with_none(): + """Test that None is handled during unquoting.""" + + assert None is strip_matching_quotes(None) + + +def test_strip_quotes_with_quotes(): + """Test that strings with quotes in them are handled during unquoting.""" + + s1 = 'Darth Vader said, "Luke, I am your father."' + assert s1 == strip_matching_quotes(s1) + + s2 = '"Darth Vader said, "Luke, I am your father.""' + assert s2[1:-1] == strip_matching_quotes(s2) diff --git a/test/pytests/test_delimitercommand.py b/test/pytests/test_delimitercommand.py new file mode 100644 index 00000000..aefd8e40 --- /dev/null +++ b/test/pytests/test_delimitercommand.py @@ -0,0 +1,100 @@ +# type: ignore + +from __future__ import annotations + +from mycli.packages.special.delimitercommand import DelimiterCommand + + +def test_delimiter_command_defaults_to_semicolon() -> None: + command = DelimiterCommand() + + assert command.current == ';' + + +def test_set_uses_first_argument_token_and_updates_current_delimiter() -> None: + command = DelimiterCommand() + + result = command.set('$$ select 1 $$') + + assert result[0].status == 'Changed delimiter to $$' + assert command.current == '$$' + + +def test_set_rejects_missing_argument() -> None: + command = DelimiterCommand() + + result = command.set('') + + assert result[0].status == 'Missing required argument, delimiter' + assert command.current == ';' + + +def test_set_rejects_delimiter_keyword_case_insensitively() -> None: + command = DelimiterCommand() + + result = command.set('Delimiter') + + assert result[0].status == 'Invalid delimiter "delimiter"' + assert command.current == ';' + + +def test_queries_iter_preserves_statement_text_for_multi_character_delimiter() -> None: + command = DelimiterCommand() + command.set('end') + + assert list(command.queries_iter('delete 1end')) == ['delete 1'] + + +def test_queries_iter_with_custom_delimiter_preserves_semicolons_inside_statement() -> None: + command = DelimiterCommand() + command.set('$$') + + assert list(command.queries_iter('select 1; select 2$$ select 3$$')) == [ + 'select 1; select 2', + 'select 3', + ] + + +def test_split_handles_placeholder_collision_in_original_sql() -> None: + command = DelimiterCommand() + command.set('$$') + + assert command._split('select \ufffc1; select 2$$ select 3$$') == [ + 'select \ufffc1; select 2$$', + 'select 3$$', + ] + + +def test_queries_iter_resplits_remaining_input_after_delimiter_change() -> None: + command = DelimiterCommand() + queries = command.queries_iter('select 1; delimiter $$ select 2$$ select 3$$') + + assert next(queries) == 'select 1' + assert next(queries) == 'delimiter $$ select 2$$ select 3$$' + + command.set('$$') + + assert list(queries) == ['select 2', 'select 3'] + + +def test_queries_iter_reappends_old_trailing_delimiter_before_resplitting(monkeypatch) -> None: + command = DelimiterCommand() + command._delimiter = ';;' + split_calls: list[str] = [] + + def fake_split(sql: str) -> list[str]: + split_calls.append(sql) + if len(split_calls) == 1: + return ['delimiter $$;;', 'select 2$$'] + return ['ignored', 'select 2'] + + monkeypatch.setattr(command, '_split', fake_split) + + queries = command.queries_iter('ignored') + + assert next(queries) == 'delimiter $$' + + command.set('$$') + + assert list(queries) == ['select 2'] + assert split_calls == ['ignored', 'delimiter $$ select 2$$;;'] diff --git a/test/pytests/test_favoritequeries.py b/test/pytests/test_favoritequeries.py new file mode 100644 index 00000000..c3c3aee7 --- /dev/null +++ b/test/pytests/test_favoritequeries.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping + +from mycli.packages.special.favoritequeries import FavoriteQueries + + +class DummyConfig(dict): + def __init__(self, initial: Mapping[str, object] | None = None) -> None: + super().__init__(initial or {}) + self.encoding: str | None = None + self.write_calls = 0 + + def write(self) -> None: + self.write_calls += 1 + + +def test_from_config_returns_instance_with_same_config() -> None: + config = DummyConfig() + + favorites = FavoriteQueries.from_config(config) + + assert isinstance(favorites, FavoriteQueries) + assert favorites.config is config + + +def test_list_and_get_use_favorite_queries_section() -> None: + config = DummyConfig({ + 'favorite_queries': { + 'daily': 'select 1', + 'weekly': 'select 2', + }, + }) + favorites = FavoriteQueries(config) + + assert favorites.list() == ['daily', 'weekly'] + assert favorites.get('daily') == 'select 1' + assert favorites.get('missing') is None + + +def test_list_returns_empty_list_when_section_is_missing() -> None: + favorites = FavoriteQueries(DummyConfig()) + + assert favorites.list() == [] + + +def test_save_creates_section_sets_encoding_and_writes_config() -> None: + config = DummyConfig() + favorites = FavoriteQueries(config) + + favorites.save('demo', 'select 1') + + assert config.encoding == 'utf-8' + assert config == {'favorite_queries': {'demo': 'select 1'}} + assert config.write_calls == 1 + + +def test_save_updates_existing_section_and_writes_config() -> None: + config = DummyConfig({'favorite_queries': {'demo': 'select 1'}}) + favorites = FavoriteQueries(config) + + favorites.save('report', 'select 2') + + assert config.encoding == 'utf-8' + assert config['favorite_queries'] == { + 'demo': 'select 1', + 'report': 'select 2', + } + assert config.write_calls == 1 + + +def test_delete_removes_existing_favorite_and_writes_config() -> None: + config = DummyConfig({'favorite_queries': {'demo': 'select 1'}}) + favorites = FavoriteQueries(config) + + result = favorites.delete('demo') + + assert result == 'demo: Deleted.' + assert config['favorite_queries'] == {} + assert config.write_calls == 1 + + +def test_delete_returns_not_found_without_writing_config() -> None: + config = DummyConfig({'favorite_queries': {'demo': 'select 1'}}) + favorites = FavoriteQueries(config) + + result = favorites.delete('missing') + + assert result == 'missing: Not Found.' + assert config['favorite_queries'] == {'demo': 'select 1'} + assert config.write_calls == 0 + + +def test_delete_returns_not_found_when_section_is_missing() -> None: + config = DummyConfig() + favorites = FavoriteQueries(config) + + result = favorites.delete('missing') + + assert result == 'missing: Not Found.' + assert config == {} + assert config.write_calls == 0 diff --git a/test/pytests/test_filepaths.py b/test/pytests/test_filepaths.py new file mode 100644 index 00000000..3fb8e1ff --- /dev/null +++ b/test/pytests/test_filepaths.py @@ -0,0 +1,126 @@ +import importlib.util +import os +from pathlib import Path +import platform +import sys +from types import ModuleType +from typing import Any + +import pytest + +from mycli.packages import filepaths + + +def load_filepaths_variant( + monkeypatch: pytest.MonkeyPatch, + *, + os_name: str, + system_name: str, +) -> ModuleType: + module_path = str(Path(filepaths.__file__).resolve()) + monkeypatch.setattr(os, 'name', os_name, raising=False) + monkeypatch.setattr(platform, 'system', lambda: system_name) + module_name = f'filepaths_variant_{os_name}_{system_name}' + spec = importlib.util.spec_from_file_location(module_name, module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_default_socket_dirs_import_variants(monkeypatch: pytest.MonkeyPatch) -> None: + darwin = load_filepaths_variant(monkeypatch, os_name='posix', system_name='Darwin') + assert darwin.DEFAULT_SOCKET_DIRS == ['/tmp'] + + linux = load_filepaths_variant(monkeypatch, os_name='posix', system_name='Linux') + assert linux.DEFAULT_SOCKET_DIRS == ['/var/run', '/var/lib'] + + windows = load_filepaths_variant(monkeypatch, os_name='nt', system_name='Windows') + assert windows.DEFAULT_SOCKET_DIRS == [] + + +def test_list_path_lists_sql_files_and_directories(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / '.hidden.sql').write_text('select 1\n', encoding='utf-8') + (tmp_path / 'visible.SQL').write_text('select 1\n', encoding='utf-8') + (tmp_path / 'notes.txt').write_text('ignored\n', encoding='utf-8') + (tmp_path / 'folder').mkdir() + + assert filepaths.list_path(str(tmp_path)) == ['visible.SQL', 'folder/'] + assert filepaths.list_path(str(tmp_path / 'missing')) == [] + + +def test_complete_path_and_parse_path() -> None: + assert filepaths.complete_path('abc', '') == 'abc' + assert filepaths.complete_path('abcdef', 'abc') == 'abcdef' + assert filepaths.complete_path('docs', '~') == os.path.join('~', 'docs') + assert filepaths.complete_path('docs', 'other') == '' + + assert filepaths.parse_path('') == ('', '', 0) + assert filepaths.parse_path('/tmp/query.sql') == ('/tmp', 'query.sql', -9) + assert filepaths.parse_path('/tmp/dir/') == ('/tmp/dir', '', 0) + + +def test_suggest_path_branches(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + (tmp_path / 'query.sql').write_text('select 1\n', encoding='utf-8') + (tmp_path / 'subdir').mkdir() + + assert filepaths.suggest_path('') == [ + os.path.abspath(os.sep), + '~', + os.curdir, + os.pardir, + 'query.sql', + 'subdir/', + ] + + assert filepaths.suggest_path('relative') == ['query.sql', 'subdir/'] + + home = tmp_path / 'home' + home.mkdir() + (home / 'from_home.sql').write_text('select 1\n', encoding='utf-8') + monkeypatch.setattr(os.path, 'expanduser', lambda path: str(home)) + assert filepaths.suggest_path('~/f') == ['from_home.sql'] + + nested = tmp_path / 'nested' + nested.mkdir() + (nested / 'inside.sql').write_text('select 1\n', encoding='utf-8') + assert filepaths.suggest_path(str(nested / 'missing.sql')) == ['inside.sql'] + + +def test_dir_path_exists(tmp_path: Path) -> None: + existing = tmp_path / 'logs' / 'mycli.log' + existing.parent.mkdir() + assert filepaths.dir_path_exists(str(existing)) is True + assert filepaths.dir_path_exists(str(tmp_path / 'missing' / 'mycli.log')) is False + + +def test_guess_socket_location_returns_matching_socket(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(filepaths, 'DEFAULT_SOCKET_DIRS', ['/a', '/b']) + monkeypatch.setattr(filepaths.os.path, 'exists', lambda path: path == '/b') + monkeypatch.setattr( + filepaths.os, + 'walk', + lambda directory, topdown=True: iter([ + ('/b', ['mysql-data', 'other'], ['mysqlx.sock', 'mysql.socket']), + ]), + ) + assert filepaths.guess_socket_location() == '/b/mysql.socket' + + +def test_guess_socket_location_prunes_dirs_and_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(filepaths, 'DEFAULT_SOCKET_DIRS', ['/a']) + monkeypatch.setattr(filepaths.os.path, 'exists', lambda path: True) + walked_dirs: list[list[str]] = [] + + def fake_walk(directory: str, topdown: bool = True) -> Any: + dirs = ['mysql-data', 'tmp', 'mysqlx', 'other'] + walked_dirs.append(dirs) + yield (directory, dirs, ['mysqlx.sock', 'readme.txt']) + + monkeypatch.setattr(filepaths.os, 'walk', fake_walk) + assert filepaths.guess_socket_location() is None + assert walked_dirs[0] == ['mysql-data', 'mysqlx'] diff --git a/test/pytests/test_hybrid_redirection.py b/test/pytests/test_hybrid_redirection.py new file mode 100644 index 00000000..1b6d79b0 --- /dev/null +++ b/test/pytests/test_hybrid_redirection.py @@ -0,0 +1,135 @@ +from typing import Generator + +import pytest +import sqlglot + +from mycli.packages import hybrid_redirection + + +def tokenize(command: str) -> list[sqlglot.Token]: + return sqlglot.tokenize(command) + + +@pytest.fixture() +def reset_hybrid_redirection(monkeypatch) -> Generator[None, None, None]: + monkeypatch.setattr(hybrid_redirection, 'WIN', False) + original_delimiter = hybrid_redirection.delimiter_command.current + hybrid_redirection.delimiter_command._delimiter = ';' + yield + hybrid_redirection.delimiter_command._delimiter = original_delimiter + + +def test_find_token_indices_tracks_true_dollars_and_operators() -> None: + tokens = tokenize('select 1 $| cat $>> out.txt') + + assert hybrid_redirection.find_token_indices(tokens) == { + 'raw_dollar': [2, 5], + 'true_dollar': [2, 5], + 'angle_bracket': [6], + 'pipe': [3], + } + + +# todo there are still corner cases combining custom delimiters and redirection +def test_find_sql_part_handles_valid_parse_custom_delimiter_and_invalid_sql(reset_hybrid_redirection) -> None: + hybrid_redirection.delimiter_command._delimiter = '$$' + valid_tokens = tokenize('select 1 $$ $> out.txt') + assert hybrid_redirection.find_sql_part('select 1 $$ $> out.txt', valid_tokens, [3]) == 'select 1' + + invalid_tokens = tokenize('select from $> out.txt') + assert hybrid_redirection.find_sql_part('select from $> out.txt', invalid_tokens, [2]) == '' + + multiple_tokens = tokenize('select 1; select 2 $> out.txt') + assert hybrid_redirection.find_sql_part('select 1; select 2 $> out.txt', multiple_tokens, [5]) == '' + + +def test_find_command_and_file_tokens_extract_expected_parts() -> None: + tokens = tokenize('select 1 $| cat $>> out.txt') + indices = hybrid_redirection.find_token_indices(tokens) + + file_tokens, file_index, operator = hybrid_redirection.find_file_tokens(tokens, indices['angle_bracket']) + command_tokens = hybrid_redirection.find_command_tokens(tokens[0:file_index], indices['true_dollar']) + + assert operator == '>>' + assert file_index == 6 + assert hybrid_redirection.assemble_tokens(file_tokens) == 'out.txt' + assert hybrid_redirection.assemble_tokens(command_tokens) == 'cat' + + +def test_find_file_tokens_returns_empty_when_no_redirect_file() -> None: + tokens = tokenize('select 1 $| cat') + + file_tokens, file_index, operator = hybrid_redirection.find_file_tokens(tokens, []) + + assert file_tokens == [] + assert file_index == len(tokens) + assert operator is None + + +def test_assemble_tokens_quotes_identifier_and_string() -> None: + identifier_tokens = tokenize('echo hi $> "quoted.txt"')[4:] + string_tokens = tokenize("echo hi $| 'printf'")[4:] + + assert hybrid_redirection.assemble_tokens(identifier_tokens) == '"quoted.txt"' + assert hybrid_redirection.assemble_tokens(string_tokens) == "'printf'" + + +@pytest.mark.parametrize( + ('file_part', 'command_part', 'expected'), + [ + ('two words.txt', None, True), + ('bad>file.txt', None, True), + (None, None, True), + ('out.txt', None, False), + (None, 'cat', False), + ], +) +def test_invalid_shell_part(file_part: str | None, command_part: str | None, expected: bool) -> None: + assert hybrid_redirection.invalid_shell_part(file_part, command_part) is expected + + +def test_get_redirect_components_valid_paths_and_logging() -> None: + assert hybrid_redirection.get_redirect_components('select 1 $>> out.txt') == ( + 'select 1', + None, + '>>', + 'out.txt', + ) + assert hybrid_redirection.get_redirect_components('select 1 $| cat $> out.txt') == ( + 'select 1', + 'cat', + '>', + 'out.txt', + ) + + +def test_get_redirect_components_returns_none_on_token_error(monkeypatch) -> None: + monkeypatch.setattr( + hybrid_redirection.sqlglot, 'tokenize', lambda command: (_ for _ in ()).throw(sqlglot.errors.TokenError('bad token')) + ) + + assert hybrid_redirection.get_redirect_components('select 1 $> out.txt') == (None, None, None, None) + + +def test_get_redirect_components_rejects_invalid_forms() -> None: + assert hybrid_redirection.get_redirect_components('select 1') == (None, None, None, None) + assert hybrid_redirection.get_redirect_components('select 1 $> out.txt $> other.txt') == (None, None, None, None) + assert hybrid_redirection.get_redirect_components('select 1 $> out.txt $| cat') == (None, None, None, None) + assert hybrid_redirection.get_redirect_components('select from $> out.txt') == (None, None, None, None) + assert hybrid_redirection.get_redirect_components('select 1 $> "two words.txt"') == (None, None, None, None) + + +def test_get_redirect_components_rejects_multiple_pipes_on_windows(monkeypatch) -> None: + monkeypatch.setattr(hybrid_redirection, 'WIN', True) + + assert hybrid_redirection.get_redirect_components('select 1 $| cat $| more') == ( + None, + None, + None, + None, + ) + + +def test_is_redirect_command_reflects_component_parsing() -> None: + assert hybrid_redirection.is_redirect_command('select 1 $| cat') is True + assert hybrid_redirection.is_redirect_command('select 1') is False diff --git a/test/pytests/test_interactive_utils.py b/test/pytests/test_interactive_utils.py new file mode 100644 index 00000000..66182c93 --- /dev/null +++ b/test/pytests/test_interactive_utils.py @@ -0,0 +1,169 @@ +from types import SimpleNamespace + +import click +import pytest + +from mycli.packages import interactive_utils + + +def test_confirm_bool_param_type_converts_bool_and_strings() -> None: + boolean_type = interactive_utils.ConfirmBoolParamType() + + assert boolean_type.convert(True, None, None) is True + assert boolean_type.convert(False, None, None) is False + assert boolean_type.convert('YES', None, None) is True + assert boolean_type.convert('y', None, None) is True + assert boolean_type.convert('NO', None, None) is False + assert boolean_type.convert('n', None, None) is False + assert repr(boolean_type) == 'BOOL' + + +def test_confirm_bool_param_type_rejects_invalid_string() -> None: + boolean_type = interactive_utils.ConfirmBoolParamType() + + with pytest.raises(click.BadParameter, match='maybe is not a valid boolean'): + boolean_type.convert('maybe', None, None) + + +def test_confirm_destructive_query_returns_none_when_not_destructive(monkeypatch: pytest.MonkeyPatch) -> None: + prompt_called = False + destructive_calls: list[tuple[list[str], str]] = [] + + def fake_prompt(*args: object, **kwargs: object) -> bool: + nonlocal prompt_called + prompt_called = True + return True + + def fake_is_destructive(keywords: list[str], query: str) -> bool: + destructive_calls.append((keywords, query)) + return False + + monkeypatch.setattr(interactive_utils, 'is_destructive', fake_is_destructive) + monkeypatch.setattr(interactive_utils, 'prompt', fake_prompt) + monkeypatch.setattr(interactive_utils.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + + keywords = ['drop'] + query = 'select 1;' + assert interactive_utils.confirm_destructive_query(keywords, query) is None + assert destructive_calls == [(keywords, query)] + assert prompt_called is False + + +def test_confirm_destructive_query_returns_none_without_tty(monkeypatch: pytest.MonkeyPatch) -> None: + prompt_called = False + + def fake_prompt(*args: object, **kwargs: object) -> bool: + nonlocal prompt_called + prompt_called = True + return True + + monkeypatch.setattr(interactive_utils, 'is_destructive', lambda keywords, query: True) + monkeypatch.setattr(interactive_utils, 'prompt', fake_prompt) + monkeypatch.setattr(interactive_utils.sys, 'stdin', SimpleNamespace(isatty=lambda: False)) + + keywords = ['drop'] + sql = 'drop database foo;' + assert interactive_utils.confirm_destructive_query(keywords, sql) is None + assert prompt_called is False + + +def test_confirm_destructive_query_prompts_and_returns_user_choice(monkeypatch: pytest.MonkeyPatch) -> None: + prompt_calls: list[tuple[tuple[object, ...], dict[str, object]]] = [] + destructive_calls: list[tuple[list[str], str]] = [] + + def fake_prompt(*args: object, **kwargs: object) -> bool: + prompt_calls.append((args, dict(kwargs))) + return True + + def fake_is_destructive(keywords: list[str], query: str) -> bool: + destructive_calls.append((keywords, query)) + return True + + monkeypatch.setattr(interactive_utils, 'is_destructive', fake_is_destructive) + monkeypatch.setattr(interactive_utils, 'prompt', fake_prompt) + monkeypatch.setattr(interactive_utils.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + + keywords = ['drop'] + query = 'drop database foo;' + result = interactive_utils.confirm_destructive_query(keywords, query) + + assert result is True + assert destructive_calls == [(keywords, query)] + assert prompt_calls == [ + ( + ("You're about to run a destructive command.\nDo you want to proceed? (y/n)",), + {'type': interactive_utils.BOOLEAN_TYPE}, + ) + ] + + +def test_confirm_destructive_query_returns_false_when_user_rejects(monkeypatch: pytest.MonkeyPatch) -> None: + prompt_calls: list[tuple[tuple[object, ...], dict[str, object]]] = [] + destructive_calls: list[tuple[list[str], str]] = [] + + def fake_prompt(*args: object, **kwargs: object) -> bool: + prompt_calls.append((args, dict(kwargs))) + return False + + def fake_is_destructive(keywords: list[str], query: str) -> bool: + destructive_calls.append((keywords, query)) + return True + + monkeypatch.setattr(interactive_utils, 'is_destructive', fake_is_destructive) + monkeypatch.setattr(interactive_utils, 'prompt', fake_prompt) + monkeypatch.setattr(interactive_utils.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + + keywords = ['drop'] + query = 'drop database foo;' + assert interactive_utils.confirm_destructive_query(keywords, query) is False + assert destructive_calls == [(keywords, query)] + assert prompt_calls == [ + ( + ("You're about to run a destructive command.\nDo you want to proceed? (y/n)",), + {'type': interactive_utils.BOOLEAN_TYPE}, + ) + ] + + +def test_confirm_returns_false_on_click_abort(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_confirm(*args: object, **kwargs: object) -> bool: + raise click.Abort() + + monkeypatch.setattr(click, 'confirm', fake_confirm) + + assert interactive_utils.confirm('continue?') is False + + +def test_confirm_delegates_to_click_confirm(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + def fake_confirm(*args: object, **kwargs: object) -> bool: + calls.append((args, dict(kwargs))) + return True + + monkeypatch.setattr(click, 'confirm', fake_confirm) + + assert interactive_utils.confirm('continue?', default=True) is True + assert calls == [(('continue?',), {'default': True})] + + +def test_prompt_returns_false_on_click_abort(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_prompt(*args: object, **kwargs: object) -> bool: + raise click.Abort() + + monkeypatch.setattr(click, 'prompt', fake_prompt) + + assert interactive_utils.prompt('continue?') is False + + +def test_prompt_delegates_to_click_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[tuple[object, ...], dict[str, object]]] = [] + + def fake_prompt(*args: object, **kwargs: object) -> bool: + calls.append((args, dict(kwargs))) + return True + + monkeypatch.setattr(click, 'prompt', fake_prompt) + + assert interactive_utils.prompt('continue?', type=interactive_utils.BOOLEAN_TYPE) is True + assert calls == [(('continue?',), {'type': interactive_utils.BOOLEAN_TYPE})] diff --git a/test/pytests/test_key_binding_utils.py b/test/pytests/test_key_binding_utils.py new file mode 100644 index 00000000..bbb3d619 --- /dev/null +++ b/test/pytests/test_key_binding_utils.py @@ -0,0 +1,228 @@ +import datetime +from typing import Any, cast + +import pytest + +from mycli.packages import key_binding_utils + + +class FakeSQLExecute: + def __init__(self, now_value: datetime.datetime) -> None: + self.now_value = now_value + + def now(self) -> datetime.datetime: + return self.now_value + + +class FakePromptSession: + def __init__(self, responses: list[object]) -> None: + self.responses = list(responses) + self.prompt_calls: list[dict[str, Any]] = [] + + def prompt(self, *, default: str, inputhook: Any, message: Any) -> str: + self.prompt_calls.append({ + 'default': default, + 'inputhook': inputhook, + 'message': message, + }) + response = self.responses.pop(0) + if isinstance(response, BaseException): + raise response + return cast(str, response) + + +class FakeMyCli: + def __init__( + self, + *, + prompt_session: FakePromptSession | None = None, + last_query: str = 'last query', + ) -> None: + self.prompt_session = prompt_session + self.last_query = last_query + self.toolbar_error_message: str | None = None + + def get_last_query(self) -> str: + return self.last_query + + +def test_server_date_returns_quoted_and_unquoted_values() -> None: + sqlexecute = FakeSQLExecute(datetime.datetime(2026, 4, 3, 14, 5, 6)) + + assert key_binding_utils.server_date(cast(Any, sqlexecute)) == '2026-04-03' + assert key_binding_utils.server_date(cast(Any, sqlexecute), quoted=True) == "'2026-04-03'" + + +def test_server_datetime_returns_quoted_and_unquoted_values() -> None: + sqlexecute = FakeSQLExecute(datetime.datetime(2026, 4, 3, 14, 5, 6)) + + assert key_binding_utils.server_datetime(cast(Any, sqlexecute)) == '2026-04-03 14:05:06' + assert key_binding_utils.server_datetime(cast(Any, sqlexecute), quoted=True) == "'2026-04-03 14:05:06'" + + +def test_prettify_statement(): + statement = 'SELECT 1' + mycli = FakeMyCli() + pretty_statement = key_binding_utils.handle_prettify_binding(cast(Any, mycli), statement) + assert pretty_statement == 'SELECT\n 1;' + + +def test_unprettify_statement(): + statement = 'SELECT\n 1' + mycli = FakeMyCli() + unpretty_statement = key_binding_utils.handle_unprettify_binding(cast(Any, mycli), statement) + assert unpretty_statement == 'SELECT 1;' + + +def test_handle_editor_command_returns_text_unchanged_when_not_editor_command(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(key_binding_utils.special, 'editor_command', lambda text: False) + + mycli = FakeMyCli() + + assert key_binding_utils.handle_editor_command(cast(Any, mycli), 'select 1', None, lambda: 'loaded') == 'select 1' + + +def test_handle_editor_command_opens_editor_reprompts_after_keyboard_interrupt_and_returns_text(monkeypatch: pytest.MonkeyPatch) -> None: + prompt_session = FakePromptSession([KeyboardInterrupt(), 'edited sql']) + mycli = FakeMyCli(prompt_session=prompt_session) + open_calls: list[dict[str, str]] = [] + + def inputhook(*args: object, **kwargs: object) -> None: + return None + + def loaded_message_fn() -> str: + return 'loaded' + + def open_external_editor(*, filename: str | None, sql: str) -> tuple[str, str | None]: + open_calls.append({'filename': cast(str, filename), 'sql': sql}) + return 'SELECT 1', None + + monkeypatch.setattr(key_binding_utils, 'PromptSession', FakePromptSession) + monkeypatch.setattr(key_binding_utils.special, 'editor_command', lambda text: text in {'\\e', ''}) + monkeypatch.setattr(key_binding_utils.special, 'get_filename', lambda text: 'query.sql') + monkeypatch.setattr(key_binding_utils.special, 'get_editor_query', lambda text: '' if text == '\\e' else None) + monkeypatch.setattr( + key_binding_utils.special, + 'open_external_editor', + open_external_editor, + ) + + result = key_binding_utils.handle_editor_command(cast(Any, mycli), '\\e', inputhook, loaded_message_fn) + + assert result == 'edited sql' + assert open_calls == [{'filename': 'query.sql', 'sql': 'last query'}] + assert prompt_session.prompt_calls == [ + {'default': 'SELECT 1', 'inputhook': inputhook, 'message': loaded_message_fn}, + {'default': '', 'inputhook': inputhook, 'message': loaded_message_fn}, + ] + + +def test_handle_editor_command_uses_explicit_editor_query_and_raises_on_editor_error(monkeypatch: pytest.MonkeyPatch) -> None: + mycli = FakeMyCli(prompt_session=FakePromptSession([])) + + monkeypatch.setattr(key_binding_utils.special, 'editor_command', lambda text: True) + monkeypatch.setattr(key_binding_utils.special, 'get_filename', lambda text: 'query.sql') + monkeypatch.setattr(key_binding_utils.special, 'get_editor_query', lambda text: 'select from text') + monkeypatch.setattr( + key_binding_utils.special, + 'open_external_editor', + lambda *, filename, sql: ('', 'editor failed'), + ) + + with pytest.raises(RuntimeError, match='editor failed'): + key_binding_utils.handle_editor_command(cast(Any, mycli), '\\eselect 1', None, lambda: 'loaded') + + +def test_handle_clip_command_returns_false_when_not_clip_command(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(key_binding_utils.special, 'clip_command', lambda text: False) + + mycli = FakeMyCli() + + assert key_binding_utils.handle_clip_command(cast(Any, mycli), 'select 1') is False + + +def test_handle_clip_command_copies_explicit_query(monkeypatch: pytest.MonkeyPatch) -> None: + clipboard_calls: list[str] = [] + + def copy_query_to_clipboard(*, sql: str) -> None: + clipboard_calls.append(sql) + + monkeypatch.setattr(key_binding_utils.special, 'clip_command', lambda text: True) + monkeypatch.setattr(key_binding_utils.special, 'get_clip_query', lambda text: 'select 1') + monkeypatch.setattr( + key_binding_utils.special, + 'copy_query_to_clipboard', + copy_query_to_clipboard, + ) + + mycli = FakeMyCli() + + assert key_binding_utils.handle_clip_command(cast(Any, mycli), '\\clip select 1') is True + assert clipboard_calls == ['select 1'] + + +def test_handle_clip_command_uses_last_query_and_raises_on_clipboard_error(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(key_binding_utils.special, 'clip_command', lambda text: True) + monkeypatch.setattr(key_binding_utils.special, 'get_clip_query', lambda text: '') + monkeypatch.setattr( + key_binding_utils.special, + 'copy_query_to_clipboard', + lambda *, sql: 'clipboard failed', + ) + + mycli = FakeMyCli() + + with pytest.raises(RuntimeError, match='clipboard failed'): + key_binding_utils.handle_clip_command(cast(Any, mycli), '\\clip') + + +def test_prettify_statement_returns_empty_string_for_empty_input() -> None: + mycli = FakeMyCli() + assert key_binding_utils.handle_prettify_binding(cast(Any, mycli), '') == '' + + +def test_unprettify_statement_returns_empty_string_for_empty_input() -> None: + mycli = FakeMyCli() + assert key_binding_utils.handle_unprettify_binding(cast(Any, mycli), '') == '' + + +@pytest.mark.parametrize( + ('handler_name', 'text'), + [ + ('handle_prettify_binding', 'SELECT 1;'), + ('handle_unprettify_binding', 'SELECT 1;'), + ], +) +def test_prettify_helpers_fall_back_to_input_without_trailing_semicolon_on_parse_error( + monkeypatch: pytest.MonkeyPatch, + handler_name: str, + text: str, +) -> None: + monkeypatch.setattr(key_binding_utils.sqlglot, 'parse', lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError('bad sql'))) + + handler = getattr(key_binding_utils, handler_name) + + mycli = FakeMyCli() + + assert handler(cast(Any, mycli), text) == 'SELECT 1' + + +@pytest.mark.parametrize( + ('handler_name', 'text'), + [ + ('handle_prettify_binding', 'SELECT 1; SELECT 2;'), + ('handle_unprettify_binding', 'SELECT 1; SELECT 2;'), + ], +) +def test_prettify_helpers_fall_back_when_parse_returns_multiple_statements( + monkeypatch: pytest.MonkeyPatch, + handler_name: str, + text: str, +) -> None: + monkeypatch.setattr(key_binding_utils.sqlglot, 'parse', lambda *_args, **_kwargs: [object(), object()]) + + handler = getattr(key_binding_utils, handler_name) + + mycli = FakeMyCli() + + assert handler(cast(Any, mycli), text) == 'SELECT 1; SELECT 2' diff --git a/test/pytests/test_key_bindings.py b/test/pytests/test_key_bindings.py new file mode 100644 index 00000000..dd169d09 --- /dev/null +++ b/test/pytests/test_key_bindings.py @@ -0,0 +1,681 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any, Callable, cast + +import prompt_toolkit +from prompt_toolkit.enums import EditingMode +from prompt_toolkit.key_binding.key_processor import KeyPressEvent +from prompt_toolkit.keys import Keys +from prompt_toolkit.layout.controls import BufferControl, SearchBufferControl +from prompt_toolkit.selection import SelectionType +import pytest + +from mycli import key_bindings + + +@dataclass +class DummyKeysConfig: + behaviors: dict[str, list[str]] = field(default_factory=dict) + options: dict[str, str] = field(default_factory=dict) + + def as_list(self, name: str) -> list[str]: + return self.behaviors[name] + + def get(self, name: str, default: str | None = None) -> str | None: + return self.options.get(name, default) + + +@dataclass +class DummyOutput: + bell_calls: int = 0 + + def bell(self) -> None: + self.bell_calls += 1 + + +@dataclass +class DummyBuffer: + text: str = '' + complete_state: object | None = None + complete_next_calls: int = 0 + cancel_completion_calls: int = 0 + open_in_editor_calls: list[bool] = field(default_factory=list) + start_completion_calls: list[dict[str, bool]] = field(default_factory=list) + start_selection_calls: list[SelectionType] = field(default_factory=list) + transform_calls: list[tuple[int, int, Callable[[str], str]]] = field(default_factory=list) + inserted_text: list[str] = field(default_factory=list) + validate_calls: int = 0 + + def complete_next(self) -> None: + self.complete_next_calls += 1 + + def start_completion( + self, + select_first: bool = False, + insert_common_part: bool = False, + ) -> None: + self.start_completion_calls.append({ + 'select_first': select_first, + 'insert_common_part': insert_common_part, + }) + self.complete_state = object() + + def cancel_completion(self) -> None: + self.cancel_completion_calls += 1 + self.complete_state = None + + def open_in_editor(self, validate_and_handle: bool) -> None: + self.open_in_editor_calls.append(validate_and_handle) + + def start_selection(self, selection_type: SelectionType) -> None: + self.start_selection_calls.append(selection_type) + + def transform_region(self, start: int, end: int, handler: Callable[[str], str]) -> None: + self.transform_calls.append((start, end, handler)) + + def insert_text(self, text: str) -> None: + self.inserted_text.append(text) + + def validate_and_handle(self) -> None: + self.validate_calls += 1 + + +@dataclass +class DummyApp: + current_buffer: DummyBuffer + editing_mode: EditingMode = EditingMode.VI + ttimeoutlen: float | None = None + output: DummyOutput = field(default_factory=DummyOutput) + exit_calls: list[dict[str, Any]] = field(default_factory=list) + print_calls: list[Any] = field(default_factory=list) + + def exit(self, exception: type[BaseException], style: str) -> None: + self.exit_calls.append({'exception': exception, 'style': style}) + + def print_text(self, text: Any) -> None: + self.print_calls.append(text) + + +@dataclass +class DummyMyCli: + key_config: DummyKeysConfig + smart_completion: bool = True + multi_line: bool = False + key_bindings_mode: str = 'vi' + highlight_preview: bool = True + syntax_style: str = 'native' + emacs_ttimeoutlen: float = 1.5 + vi_ttimeoutlen: float = 0.5 + sqlexecute: object = field(default_factory=object) + prettify_calls: list[str] = field(default_factory=list) + unprettify_calls: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + self.completer = SimpleNamespace(smart_completion=self.smart_completion) + self.key_bindings = self.key_bindings_mode + self.config = {'keys': self.key_config} + + def handle_prettify_binding(self, text: str) -> str: + self.prettify_calls.append(text) + return text + + def handle_unprettify_binding(self, text: str) -> str: + self.unprettify_calls.append(text) + return text + + +def make_event(buffer: DummyBuffer | None = None) -> SimpleNamespace: + active_buffer = buffer or DummyBuffer() + app = DummyApp(current_buffer=active_buffer) + return SimpleNamespace(app=app, current_buffer=active_buffer) + + +def binding_handler(kb: prompt_toolkit.key_binding.KeyBindings, *keys: str | Keys) -> Callable[[Any], None]: + expected = tuple(keys) + for binding in kb.bindings: + if binding.keys == expected: + return cast(Callable[[Any], None], binding.handler) + raise AssertionError(f'binding not found for keys={expected!r}') + + +def binding_filter(kb: prompt_toolkit.key_binding.KeyBindings, *keys: str | Keys) -> Any: + expected = tuple(keys) + for binding in kb.bindings: + if binding.keys == expected: + return binding.filter + raise AssertionError(f'binding not found for keys={expected!r}') + + +def binding(kb: prompt_toolkit.key_binding.KeyBindings, *keys: str | Keys) -> Any: + expected = tuple(keys) + for entry in kb.bindings: + if entry.keys == expected: + return entry + raise AssertionError(f'binding not found for keys={expected!r}') + + +def patch_filter_app(monkeypatch, app: DummyApp) -> None: + monkeypatch.setitem(key_bindings.emacs_mode.func.__globals__, 'get_app', lambda: app) + monkeypatch.setitem(key_bindings.completion_is_selected.func.__globals__, 'get_app', lambda: app) + monkeypatch.setitem(key_bindings.control_is_searchable.func.__globals__, 'get_app', lambda: app) + + +def test_ctrl_d_condition_depends_on_empty_buffer(monkeypatch) -> None: + monkeypatch.setattr(key_bindings, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text=''))) + assert key_bindings.ctrl_d_condition() is True + + monkeypatch.setattr(key_bindings, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='select 1'))) + assert key_bindings.ctrl_d_condition() is False + + +def test_in_completion_depends_on_complete_state(monkeypatch) -> None: + monkeypatch.setattr(key_bindings, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(complete_state=object()))) + assert key_bindings.in_completion() is True + + monkeypatch.setattr(key_bindings, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(complete_state=None))) + assert key_bindings.in_completion() is False + + +def test_print_f1_help_prints_inline_help_and_docs_url(monkeypatch) -> None: + app = DummyApp(current_buffer=DummyBuffer()) + monkeypatch.setattr(key_bindings, 'get_app', lambda: app) + + key_bindings.print_f1_help() + + assert app.print_calls == [ + '\n', + [ + ('', 'Inline help — type "'), + ('bold', 'help'), + ('', '" or "'), + ('bold', r'\?'), + ('', '"\n'), + ], + [ + ('', 'Docs index — '), + ('bold', key_bindings.DOCS_URL), + ('', '\n'), + ], + '\n', + ] + + +def test_edit_and_execute_opens_editor_without_validation() -> None: + event = make_event() + + key_bindings.edit_and_execute(cast(KeyPressEvent, event)) + + assert event.current_buffer.open_in_editor_calls == [False] + + +@pytest.mark.parametrize('keys', ((Keys.F1,), (Keys.Escape, '[', 'P'))) +def test_f1_bindings_open_docs_show_help_and_invalidate(monkeypatch, keys: tuple[str | Keys, ...]) -> None: + mycli = DummyMyCli(DummyKeysConfig()) + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + browser_calls: list[str] = [] + terminal_calls: list[Callable[[], None]] = [] + invalidated: list[DummyApp] = [] + + monkeypatch.setattr(key_bindings.webbrowser, 'open_new_tab', lambda url: browser_calls.append(url)) + monkeypatch.setattr( + key_bindings.prompt_toolkit.application, + 'run_in_terminal', + lambda fn: terminal_calls.append(fn), + ) + monkeypatch.setattr(key_bindings, 'safe_invalidate_display', lambda app: invalidated.append(app)) + + binding_handler(kb, *keys)(event) + + assert browser_calls == [key_bindings.DOCS_URL] + assert terminal_calls == [key_bindings.print_f1_help] + assert invalidated == [event.app] + + +@pytest.mark.parametrize('keys', ((Keys.F2,), (Keys.Escape, '[', 'Q'))) +def test_f2_bindings_toggle_smart_completion(keys: tuple[str | Keys, ...]) -> None: + mycli = DummyMyCli(DummyKeysConfig(), smart_completion=True) + kb = key_bindings.mycli_bindings(mycli) + + binding_handler(kb, *keys)(make_event()) + + assert mycli.completer.smart_completion is False + + +@pytest.mark.parametrize('keys', ((Keys.F3,), (Keys.Escape, '[', 'R'))) +def test_f3_bindings_toggle_multiline_mode(keys: tuple[str | Keys, ...]) -> None: + mycli = DummyMyCli(DummyKeysConfig(), multi_line=False) + kb = key_bindings.mycli_bindings(mycli) + + binding_handler(kb, *keys)(make_event()) + + assert mycli.multi_line is True + + +@pytest.mark.parametrize( + ('keys', 'initial_mode', 'expected_mode', 'expected_editing_mode', 'expected_timeout'), + ( + ((Keys.F4,), 'vi', 'emacs', EditingMode.EMACS, 1.5), + ((Keys.F4,), 'emacs', 'vi', EditingMode.VI, 0.5), + ((Keys.Escape, '[', 'S'), 'vi', 'emacs', EditingMode.EMACS, 1.5), + ((Keys.Escape, '[', 'S'), 'emacs', 'vi', EditingMode.VI, 0.5), + ), +) +def test_f4_bindings_toggle_key_binding_modes( + keys: tuple[str | Keys, ...], + initial_mode: str, + expected_mode: str, + expected_editing_mode: EditingMode, + expected_timeout: float, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(), key_bindings_mode=initial_mode) + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + + binding_handler(kb, *keys)(event) + + assert mycli.key_bindings == expected_mode + assert event.app.editing_mode == expected_editing_mode + assert event.app.ttimeoutlen == expected_timeout + + +def test_tab_binding_uses_toolkit_default_to_start_completion() -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'tab': ['toolkit_default']})) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='sel')) + + binding_handler(kb, Keys.ControlI)(event) + + assert event.app.current_buffer.start_completion_calls == [{'select_first': True, 'insert_common_part': False}] + assert event.app.current_buffer.complete_next_calls == 0 + + +def test_tab_binding_uses_toolkit_default_to_advance_existing_completion() -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'tab': ['toolkit_default']})) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='sel', complete_state=object())) + + binding_handler(kb, Keys.ControlI)(event) + + assert event.app.current_buffer.complete_next_calls == 1 + + +@pytest.mark.parametrize( + ('behaviors', 'expected_start', 'expected_complete_next', 'expected_cancel'), + ( + (['advance'], [], 1, 0), + (['cancel'], [], 0, 1), + (['advancing_summon'], [{'select_first': True, 'insert_common_part': False}], 0, 0), + (['prefixing_summon'], [{'select_first': False, 'insert_common_part': True}], 0, 0), + (['summon'], [{'select_first': False, 'insert_common_part': False}], 0, 0), + ), +) +def test_tab_binding_supports_configured_behaviors( + behaviors: list[str], + expected_start: list[dict[str, bool]], + expected_complete_next: int, + expected_cancel: int, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'tab': behaviors})) + kb = key_bindings.mycli_bindings(mycli) + complete_state = object() if behaviors[0] in {'advance', 'cancel'} else None + event = make_event(DummyBuffer(text='sel', complete_state=complete_state)) + + binding_handler(kb, Keys.ControlI)(event) + + assert event.app.current_buffer.start_completion_calls == expected_start + assert event.app.current_buffer.complete_next_calls == expected_complete_next + assert event.app.current_buffer.cancel_completion_calls == expected_cancel + + +def test_escape_binding_cancels_completion_menu(monkeypatch) -> None: + mycli = DummyMyCli(DummyKeysConfig()) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(complete_state=object())) + monkeypatch.setattr(key_bindings, 'get_app', lambda: event.app) + + assert binding(kb, Keys.Escape).eager() is True + assert binding_filter(kb, Keys.Escape)() is True + + inactive_event = make_event(DummyBuffer(complete_state=None)) + monkeypatch.setattr(key_bindings, 'get_app', lambda: inactive_event.app) + assert binding_filter(kb, Keys.Escape)() is False + + monkeypatch.setattr(key_bindings, 'get_app', lambda: event.app) + + binding_handler(kb, Keys.Escape)(event) + + assert event.app.current_buffer.cancel_completion_calls == 1 + assert event.app.current_buffer.complete_state is None + + +def test_control_space_toolkit_default_starts_selection_for_non_empty_text() -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'control_space': ['toolkit_default']})) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='abc')) + + binding_handler(kb, Keys.ControlAt)(event) + + assert event.app.current_buffer.start_selection_calls == [SelectionType.CHARACTERS] + + +def test_control_space_toolkit_default_is_noop_for_empty_text() -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'control_space': ['toolkit_default']})) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='')) + + binding_handler(kb, Keys.ControlAt)(event) + + assert event.app.current_buffer.start_selection_calls == [] + assert event.app.current_buffer.start_completion_calls == [] + + +@pytest.mark.parametrize( + ('behaviors', 'expected_start', 'expected_complete_next', 'expected_cancel'), + ( + (['advance'], [], 1, 0), + (['cancel'], [], 0, 1), + (['advancing_summon'], [{'select_first': True, 'insert_common_part': False}], 0, 0), + (['prefixing_summon'], [{'select_first': False, 'insert_common_part': True}], 0, 0), + (['summon'], [{'select_first': False, 'insert_common_part': False}], 0, 0), + ), +) +def test_control_space_supports_completion_behaviors( + behaviors: list[str], + expected_start: list[dict[str, bool]], + expected_complete_next: int, + expected_cancel: int, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(behaviors={'control_space': behaviors})) + kb = key_bindings.mycli_bindings(mycli) + complete_state = object() if behaviors[0] in {'advance', 'cancel'} else None + event = make_event(DummyBuffer(text='sel', complete_state=complete_state)) + + binding_handler(kb, Keys.ControlAt)(event) + + assert event.app.current_buffer.start_completion_calls == expected_start + assert event.app.current_buffer.complete_next_calls == expected_complete_next + assert event.app.current_buffer.cancel_completion_calls == expected_cancel + + +@pytest.mark.parametrize( + ('keys', 'handler_name'), + ( + ((Keys.ControlX, 'p'), 'handle_prettify_binding'), + ((Keys.ControlX, 'u'), 'handle_unprettify_binding'), + ), +) +def test_prettify_bindings_transform_non_empty_buffer( + monkeypatch, + keys: tuple[str | Keys, ...], + handler_name: str, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(), key_bindings_mode='emacs') + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='select 1')) + event.app.editing_mode = EditingMode.EMACS + patch_filter_app(monkeypatch, event.app) + + assert binding_filter(kb, *keys)() is True + + binding_handler(kb, *keys)(event) + + assert len(event.app.current_buffer.transform_calls) == 1 + start, end, handler = event.app.current_buffer.transform_calls[0] + assert (start, end) == (0, len('select 1')) + assert handler.func is getattr(key_bindings.key_binding_utils, handler_name) + assert handler.args == (mycli,) + + +@pytest.mark.parametrize('keys', ((Keys.ControlX, 'p'), (Keys.ControlX, 'u'))) +def test_prettify_bindings_skip_empty_buffer(monkeypatch, keys: tuple[str | Keys, ...]) -> None: + mycli = DummyMyCli(DummyKeysConfig(), key_bindings_mode='emacs') + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='')) + event.app.editing_mode = EditingMode.EMACS + patch_filter_app(monkeypatch, event.app) + + binding_handler(kb, *keys)(event) + + assert event.app.current_buffer.transform_calls == [] + + +@pytest.mark.parametrize( + ('keys', 'expected_text'), + ( + ((Keys.ControlO, 'd'), 'DATE'), + ((Keys.ControlO, Keys.ControlD), "'DATE'"), + ((Keys.ControlO, 't'), 'DATETIME'), + ((Keys.ControlO, Keys.ControlT), "'DATETIME'"), + ), +) +def test_date_and_datetime_bindings_insert_shortcuts( + monkeypatch, + keys: tuple[str | Keys, ...], + expected_text: str, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(), key_bindings_mode='emacs') + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + event.app.editing_mode = EditingMode.EMACS + patch_filter_app(monkeypatch, event.app) + + monkeypatch.setattr( + key_bindings.key_binding_utils, + 'server_date', + lambda _sqlexecute, quoted=False: "'DATE'" if quoted else 'DATE', + ) + monkeypatch.setattr( + key_bindings.key_binding_utils, + 'server_datetime', + lambda _sqlexecute, quoted=False: "'DATETIME'" if quoted else 'DATETIME', + ) + + assert binding_filter(kb, *keys)() is True + + inactive_event = make_event() + inactive_event.app.editing_mode = EditingMode.VI + patch_filter_app(monkeypatch, inactive_event.app) + assert binding_filter(kb, *keys)() is False + + patch_filter_app(monkeypatch, event.app) + + binding_handler(kb, *keys)(event) + + assert event.app.current_buffer.inserted_text == [expected_text] + + +def test_control_r_uses_reverse_isearch_mode_when_configured(monkeypatch) -> None: + mycli = DummyMyCli(DummyKeysConfig(options={'control_r': 'reverse_isearch'}), key_bindings_mode='emacs') + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + event.app.editing_mode = EditingMode.EMACS + event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + vi_mode_event = make_event() + vi_mode_event.app.editing_mode = EditingMode.VI + vi_mode_event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + calls: list[dict[str, Any]] = [] + patch_filter_app(monkeypatch, event.app) + + monkeypatch.setattr( + key_bindings, + 'search_history', + lambda *args, **kwargs: calls.append({'args': args, 'kwargs': kwargs}), + ) + + assert binding_filter(kb, Keys.ControlR)() is True + + inactive_event = make_event() + inactive_event.app.editing_mode = EditingMode.EMACS + inactive_event.app.layout = SimpleNamespace(current_control=object()) + patch_filter_app(monkeypatch, inactive_event.app) + assert binding_filter(kb, Keys.ControlR)() is False + + patch_filter_app(monkeypatch, vi_mode_event.app) + assert binding_filter(kb, Keys.ControlR)() is True + + patch_filter_app(monkeypatch, event.app) + + binding_handler(kb, Keys.ControlR)(event) + patch_filter_app(monkeypatch, vi_mode_event.app) + binding_handler(kb, Keys.ControlR)(vi_mode_event) + + assert calls == [ + {'args': (event,), 'kwargs': {'incremental': True}}, + {'args': (vi_mode_event,), 'kwargs': {'incremental': True}}, + ] + + +def test_control_r_and_alt_r_use_fzf_search_options(monkeypatch) -> None: + mycli = DummyMyCli(DummyKeysConfig(), key_bindings_mode='emacs') + kb = key_bindings.mycli_bindings(mycli) + calls: list[dict[str, Any]] = [] + + monkeypatch.setattr( + key_bindings, + 'search_history', + lambda *args, **kwargs: calls.append({'args': args, 'kwargs': kwargs}), + ) + + control_r_event = make_event() + alt_r_event = make_event() + control_r_event.app.editing_mode = EditingMode.EMACS + alt_r_event.app.editing_mode = EditingMode.EMACS + control_r_event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + alt_r_event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + patch_filter_app(monkeypatch, control_r_event.app) + assert binding_filter(kb, Keys.ControlR)() is True + + inactive_control_r_event = make_event() + inactive_control_r_event.app.editing_mode = EditingMode.EMACS + inactive_control_r_event.app.layout = SimpleNamespace(current_control=object()) + patch_filter_app(monkeypatch, inactive_control_r_event.app) + assert binding_filter(kb, Keys.ControlR)() is False + + vi_mode_control_r_event = make_event() + vi_mode_control_r_event.app.editing_mode = EditingMode.VI + vi_mode_control_r_event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + patch_filter_app(monkeypatch, vi_mode_control_r_event.app) + assert binding_filter(kb, Keys.ControlR)() is True + + patch_filter_app(monkeypatch, control_r_event.app) + binding_handler(kb, Keys.ControlR)(control_r_event) + patch_filter_app(monkeypatch, vi_mode_control_r_event.app) + binding_handler(kb, Keys.ControlR)(vi_mode_control_r_event) + patch_filter_app(monkeypatch, alt_r_event.app) + assert binding_filter(kb, Keys.Escape, 'r')() is True + + vi_mode_event = make_event() + vi_mode_event.app.editing_mode = EditingMode.VI + vi_mode_event.app.layout = SimpleNamespace(current_control=BufferControl(search_buffer_control=SearchBufferControl())) + patch_filter_app(monkeypatch, vi_mode_event.app) + assert binding_filter(kb, Keys.Escape, 'r')() is False + + non_searchable_event = make_event() + non_searchable_event.app.editing_mode = EditingMode.EMACS + non_searchable_event.app.layout = SimpleNamespace(current_control=object()) + patch_filter_app(monkeypatch, non_searchable_event.app) + assert binding_filter(kb, Keys.Escape, 'r')() is False + + patch_filter_app(monkeypatch, alt_r_event.app) + binding_handler(kb, Keys.Escape, 'r')(alt_r_event) + + assert calls == [ + { + 'args': (control_r_event,), + 'kwargs': { + 'highlight_preview': True, + 'highlight_style': 'native', + }, + }, + { + 'args': (vi_mode_control_r_event,), + 'kwargs': { + 'highlight_preview': True, + 'highlight_style': 'native', + }, + }, + { + 'args': (alt_r_event,), + 'kwargs': { + 'highlight_preview': True, + 'highlight_style': 'native', + }, + }, + ] + + +@pytest.mark.parametrize( + ('mode', 'expected_exit_calls', 'expected_bells'), + ( + ('exit', [{'exception': EOFError, 'style': 'class:exiting'}], 0), + ('bell', [], 1), + ), +) +def test_control_d_binding_exits_or_bells( + monkeypatch, + mode: str, + expected_exit_calls: list[dict[str, Any]], + expected_bells: int, +) -> None: + mycli = DummyMyCli(DummyKeysConfig(options={'control_d': mode})) + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + monkeypatch.setattr(key_bindings, 'get_app', lambda: event.app) + + assert binding_filter(kb, Keys.ControlD)() is True + + inactive_event = make_event(DummyBuffer(text='select 1')) + monkeypatch.setattr(key_bindings, 'get_app', lambda: inactive_event.app) + assert binding_filter(kb, Keys.ControlD)() is False + + monkeypatch.setattr(key_bindings, 'get_app', lambda: event.app) + + binding_handler(kb, Keys.ControlD)(event) + + assert event.app.exit_calls == expected_exit_calls + assert event.app.output.bell_calls == expected_bells + + +def test_enter_binding_closes_completion_menu(monkeypatch) -> None: + mycli = DummyMyCli(DummyKeysConfig()) + kb = key_bindings.mycli_bindings(mycli) + event = make_event(DummyBuffer(text='sel', complete_state=SimpleNamespace(current_completion=object()))) + patch_filter_app(monkeypatch, event.app) + + assert binding_filter(kb, Keys.ControlM)() is True + + inactive_event = make_event(DummyBuffer(text='sel', complete_state=SimpleNamespace(current_completion=None))) + patch_filter_app(monkeypatch, inactive_event.app) + assert binding_filter(kb, Keys.ControlM)() is False + + patch_filter_app(monkeypatch, event.app) + + binding_handler(kb, Keys.ControlM)(event) + + assert event.current_buffer.complete_state is None + assert event.app.current_buffer.complete_state is None + + +@pytest.mark.parametrize( + ('multi_line', 'expected_validate_calls', 'expected_inserted_text'), + ( + (True, 1, []), + (False, 0, ['\n']), + ), +) +def test_alt_enter_binding_validates_or_inserts_newline( + multi_line: bool, + expected_validate_calls: int, + expected_inserted_text: list[str], +) -> None: + mycli = DummyMyCli(DummyKeysConfig(), multi_line=multi_line) + kb = key_bindings.mycli_bindings(mycli) + event = make_event() + + binding_handler(kb, Keys.Escape, Keys.ControlM)(event) + + assert event.app.current_buffer.validate_calls == expected_validate_calls + assert event.app.current_buffer.inserted_text == expected_inserted_text diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py new file mode 100644 index 00000000..8541f808 --- /dev/null +++ b/test/pytests/test_main.py @@ -0,0 +1,2479 @@ +# type: ignore + +from collections import namedtuple +from contextlib import redirect_stderr, redirect_stdout +import csv +import io +import os +import shutil +from tempfile import NamedTemporaryFile +from textwrap import dedent +from types import SimpleNamespace +from typing import Any, cast + +import click +from click.testing import CliRunner +import prompt_toolkit +from prompt_toolkit.formatted_text import ( + FormattedText, + to_formatted_text, + to_plain_text, +) +import pymysql +from pymysql.err import OperationalError +import pytest + +from mycli import main +from mycli.constants import ( + DEFAULT_DATABASE, + DEFAULT_HOST, + DEFAULT_PORT, + DEFAULT_USER, + TEST_DATABASE, +) +from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint +import mycli.main_modes.repl as repl_mode +import mycli.output as output_module +import mycli.packages.special +from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.packages.sqlresult import SQLResult +from mycli.sqlexecute import ServerInfo, SQLExecute +from test.utils import ( + DATABASE, + HOST, + PASSWORD, + PORT, + TEMPFILE_PREFIX, + USER, + DummyFormatter, + DummyLogger, + FakeCursorBase, + RecordingSQLExecute, + ReusableLock, + call_click_entrypoint_direct, + dbtest, + make_bare_mycli, + make_dummy_mycli_class, + run, +) + +pytests_dir = os.path.abspath(os.path.dirname(__file__)) +project_root_dir = os.path.abspath(os.path.join(pytests_dir, '..', '..')) +default_config_file = os.path.join(project_root_dir, 'test', 'myclirc') +login_path_file = os.path.join(project_root_dir, 'test', 'mylogin.cnf') + +os.environ["MYSQL_TEST_LOGIN_FILE"] = login_path_file +CLI_ARGS_WITHOUT_DB = [ + "--user", + USER, + "--host", + HOST, + "--port", + PORT, + "--password", + PASSWORD, + "--myclirc", + default_config_file, + "--defaults-file", + default_config_file, +] +CLI_ARGS = CLI_ARGS_WITHOUT_DB + [TEST_DATABASE] + + +@dbtest +def test_binary_display_hex(executor): + m = MyCli() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + m.explicit_pager = False + sqlresult = next(m.sqlexecute.run("select b'01101010' AS binary_test")) + formatted = m.format_sqlresult( + sqlresult, + is_expanded=False, + is_redirected=False, + null_string="", + numeric_alignment="right", + binary_display="hex", + max_width=None, + ) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult) + expected = " 0x6a " + output = f.getvalue() + assert expected in output + + +@dbtest +def test_binary_display_utf8(executor): + m = MyCli() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + m.explicit_pager = False + sqlresult = next(m.sqlexecute.run("select b'01101010' AS binary_test")) + formatted = m.format_sqlresult( + sqlresult, + is_expanded=False, + is_redirected=False, + null_string="", + numeric_alignment="right", + binary_display="utf8", + max_width=None, + ) + f = io.StringIO() + with redirect_stdout(f): + m.output(formatted, sqlresult) + expected = " j " + output = f.getvalue() + assert expected in output + + +@dbtest +def test_select_from_empty_table(executor): + run(executor, """create table t1(id int)""") + sql = "select * from t1" + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["-t"], input=sql) + expected = dedent("""\ + +----+ + | id | + +----+ + +----+""") + assert expected in result.output + + +def test_filtered_sys_argv_maps_single_dash_h_to_help(monkeypatch): + import mycli.main + + monkeypatch.setattr(mycli.main.sys, 'argv', ['mycli', '-h']) + + assert mycli.main.filtered_sys_argv() == ['--help'] + + +def test_filtered_sys_argv_preserves_host_option_usage(monkeypatch): + import mycli.main + + monkeypatch.setattr(mycli.main.sys, 'argv', ['mycli', '-h', 'example.com']) + + assert mycli.main.filtered_sys_argv() == ['-h', 'example.com'] + + +def test_main_dash_h_and_help_have_equivalent_output(monkeypatch): + import mycli.main + + def run_main(argv): + stdout = io.StringIO() + stderr = io.StringIO() + monkeypatch.setattr(mycli.main.sys, 'argv', argv) + with redirect_stdout(stdout), redirect_stderr(stderr): + result = mycli.main.main() + return result, stdout.getvalue(), stderr.getvalue() + + dash_h_result, dash_h_stdout, dash_h_stderr = run_main(['mycli', '-h']) + dash_help_result, dash_help_stdout, dash_help_stderr = run_main(['mycli', '--help']) + + assert dash_h_result == 0 + assert dash_help_result == 0 + assert dash_h_stdout == dash_help_stdout + assert dash_h_stderr == dash_help_stderr + + +@dbtest +def test_ssl_mode_on(executor, capsys): + runner = CliRunner() + ssl_mode = "on" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) + assert ssl_cipher + + +@dbtest +def test_ssl_mode_auto(executor, capsys): + runner = CliRunner() + ssl_mode = "auto" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) + assert ssl_cipher + + +@dbtest +def test_ssl_mode_off(executor, capsys): + runner = CliRunner() + ssl_mode = "off" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) + assert not ssl_cipher + + +@dbtest +def test_ssl_mode_overrides_ssl(executor, capsys): + runner = CliRunner() + ssl_mode = "off" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--ssl"], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) + assert not ssl_cipher + + +@dbtest +def test_ssl_mode_overrides_no_ssl(executor, capsys): + runner = CliRunner() + ssl_mode = "on" + sql = "select * from performance_schema.session_status where variable_name = 'Ssl_cipher'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv", "--ssl-mode", ssl_mode, "--no-ssl"], input=sql) + result_dict = next(csv.DictReader(result.stdout.split("\n"))) + ssl_cipher = result_dict.get("VARIABLE_VALUE", None) + assert ssl_cipher + + +@dbtest +def test_reconnect_database_is_selected(executor, capsys): + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + try: + next(m.sqlexecute.run(f"use {DATABASE}")) + next(m.sqlexecute.run(f"kill {m.sqlexecute.connection_id}")) + except OperationalError: + pass # expected as the connection was killed + except Exception as e: + raise e + m.reconnect() + try: + next(m.sqlexecute.run("show tables")).rows.fetchall() + except Exception as e: + raise e + + +@dbtest +def test_reconnect_no_database(executor, capsys): + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + sql = "\\r" + result = next(mycli.packages.special.execute(executor, sql)) + stdout, _stderr = capsys.readouterr() + assert result.status is None + assert "Already connected" in stdout + + +@dbtest +def test_reconnect_with_different_database(executor): + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + database_1 = TEST_DATABASE + database_2 = DEFAULT_DATABASE + sql_1 = f"use {database_1}" + sql_2 = f"\\r {database_2}" + _result_1 = next(mycli.packages.special.execute(executor, sql_1)) + result_2 = next(mycli.packages.special.execute(executor, sql_2)) + expected = f'You are now connected to database "{database_2}" as user "{USER}"' + assert expected in result_2.status + + +@dbtest +def test_reconnect_with_same_database(executor): + m = MyCli() + m.register_special_commands() + m.sqlexecute = SQLExecute( + None, + USER, + PASSWORD, + HOST, + PORT, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + database = DEFAULT_DATABASE + sql = f"\\u {database}" + result = next(mycli.packages.special.execute(executor, sql)) + sql = f"\\r {database}" + result = next(mycli.packages.special.execute(executor, sql)) + expected = f'You are already connected to database "{database}" as user "{USER}"' + assert expected in result.status + + +@dbtest +def test_prompt_no_host_only_socket(executor): + mycli = MyCli() + mycli.prompt_format = "\\t \\u@\\h:\\d> " + mycli.sqlexecute = SQLExecute + mycli.sqlexecute.server_info = ServerInfo.from_version_string("8.0.44-0ubuntu0.24.04.1") + mycli.sqlexecute.host = None + mycli.sqlexecute.socket = "/var/run/mysqld/mysqld.sock" + mycli.sqlexecute.user = DEFAULT_USER + mycli.sqlexecute.dbname = DEFAULT_DATABASE + mycli.sqlexecute.port = DEFAULT_PORT + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_DATABASE}> " + + +@dbtest +def test_prompt_socket_overrides_port(executor): + mycli = MyCli() + mycli.prompt_format = "\\t \\u@\\h:\\k \\d> " + mycli.sqlexecute = SQLExecute + mycli.sqlexecute.server_info = ServerInfo.from_version_string("8.0.44-0ubuntu0.24.04.1") + mycli.sqlexecute.host = None + mycli.sqlexecute.socket = "/var/run/mysqld/mysqld.sock" + mycli.sqlexecute.user = DEFAULT_USER + mycli.sqlexecute.dbname = DEFAULT_DATABASE + mycli.sqlexecute.port = DEFAULT_PORT + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:mysqld.sock {DEFAULT_DATABASE}> " + + +@dbtest +def test_prompt_socket_short_host(executor): + mycli = MyCli() + mycli.prompt_format = "\\t \\u@\\H:\\k \\d> " + mycli.sqlexecute = SQLExecute + mycli.sqlexecute.server_info = ServerInfo.from_version_string("8.0.44-0ubuntu0.24.04.1") + mycli.sqlexecute.host = f'{DEFAULT_HOST}.localdomain' + mycli.sqlexecute.socket = None + mycli.sqlexecute.user = DEFAULT_USER + mycli.sqlexecute.dbname = DEFAULT_DATABASE + mycli.sqlexecute.port = DEFAULT_PORT + prompt = repl_mode.render_prompt_string(mycli, mycli.prompt_format, 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f"MySQL {DEFAULT_USER}@{DEFAULT_HOST}:{DEFAULT_PORT} {DEFAULT_DATABASE}> " + + +@dbtest +def test_enable_show_warnings(executor): + mycli = MyCli() + mycli.register_special_commands() + sql = "\\W" + result = run(executor, sql) + assert result[0]["status"] == "Show warnings enabled." + + +@dbtest +def test_disable_show_warnings(executor): + mycli = MyCli() + mycli.register_special_commands() + sql = "\\w" + result = run(executor, sql) + assert result[0]["status"] == "Show warnings disabled." + + +@dbtest +def test_output_ddl_with_warning_and_show_warnings_enabled(executor): + runner = CliRunner() + db = TEST_DATABASE + table = "table_that_definitely_does_not_exist_1234" + sql = f"DROP TABLE IF EXISTS {db}.{table}" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--show-warnings", "--no-warn"], input=sql) + expected = f"Level\tCode\tMessage\nNote\t1051\tUnknown table '{db}.table_that_definitely_does_not_exist_1234'\n" + assert expected in result.output + + +@dbtest +def test_output_with_warning_and_show_warnings_enabled(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = "1 + '0 foo'\n1.0\nLevel\tCode\tMessage\nWarning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + assert expected in result.output + + +@dbtest +def test_output_with_warning_and_show_warnings_disabled(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--no-show-warnings"], input=sql) + expected = "1 + '0 foo'\n1.0\nLevel\tCode\tMessage\nWarning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + assert expected not in result.output + + +@dbtest +def test_no_show_warnings_overrides_myclirc_setting(executor, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + runner = CliRunner() + sql = 'EXPLAIN SELECT 1' + expected = 'select 1' + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as myclirc: + myclirc.write( + dedent("""\ + [main] + show_warnings = True + """) + ) + myclirc.flush() + args = [ + '--user', + USER, + '--host', + HOST, + '--port', + PORT, + '--password', + PASSWORD, + '--myclirc', + myclirc.name, + '--defaults-file', + default_config_file, + TEST_DATABASE, + ] + + result = runner.invoke(click_entrypoint, args=args, input=sql) + assert expected in result.output + + result = runner.invoke(click_entrypoint, args=args + ['--no-show-warnings'], input=sql) + assert expected not in result.output + + try: + if os.path.exists(myclirc.name): + os.remove(myclirc.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +@dbtest +def test_output_with_multiple_warnings_in_single_statement(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo', 2 + '0 foo'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = ( + "1 + '0 foo'\t2 + '0 foo'\n" + "1.0\t2.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + ) + assert expected in result.output + + +@dbtest +def test_output_with_multiple_warnings_in_multiple_statements(executor): + runner = CliRunner() + sql = "SELECT 1 + '0 foo'; SELECT 2 + '0 foo'" + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--show-warnings"], input=sql) + expected = ( + "1 + '0 foo'\n" + "1.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + "2 + '0 foo'\n" + "2.0\n" + "Level\tCode\tMessage\n" + "Warning\t1292\tTruncated incorrect DOUBLE value: '0 foo'\n" + ) + assert expected in result.output + + +@dbtest +def test_execute_arg(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["-e", sql]) + + assert result.exit_code == 0 + assert "abc" in result.output + + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--execute", sql]) + + assert result.exit_code == 0 + assert "abc" in result.output + + expected = "a\nabc\n" + + assert expected in result.output + + +@dbtest +def test_execute_arg_with_checkpoint(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as checkpoint: + checkpoint.close() + + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--execute", sql, f"--checkpoint={checkpoint.name}"]) + assert result.exit_code == 0 + + with open(checkpoint.name, 'r') as f: + contents = f.read() + assert sql in contents + os.remove(checkpoint.name) + + sql = 'select 10 from nonexistent_table;' + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--execute", sql, f"--checkpoint={checkpoint.name}"]) + assert result.exit_code != 0 + + with open(checkpoint.name, 'r') as f: + contents = f.read() + assert sql not in contents + + # delete=False means we should try to clean up + # we don't really need "try" here as open() would have already failed + try: + if os.path.exists(checkpoint.name): + os.remove(checkpoint.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +@dbtest +def test_execute_arg_with_table(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["-e", sql] + ["--table"]) + expected = "+-----+\n| a |\n+-----+\n| abc |\n+-----+\n" + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_arg_with_csv(executor): + run(executor, "create table test (a text)") + run(executor, 'insert into test values("abc")') + + sql = "select * from test;" + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["-e", sql] + ["--csv"]) + expected = '"a"\n"abc"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +@dbtest +def test_batch_mode(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") + + sql = "select count(*) from test;\nselect * from test limit 1;" + + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert "count(*)\n3\na\nabc\n" in "".join(result.output) + + +@dbtest +def test_batch_mode_multiline_statement(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") + + sql = "select count(*)\nfrom test;\nselect * from test limit 1;" + + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS, input=sql) + + assert result.exit_code == 0 + assert "count(*)\n3\na\nabc\n" in "".join(result.output) + + +@dbtest +def test_batch_mode_table(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc'), ('def'), ('ghi')""") + + sql = "select count(*) from test;\nselect * from test limit 1;" + + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["-t"], input=sql) + + expected = dedent("""\ + +----------+ + | count(*) | + +----------+ + | 3 | + +----------+ + +-----+ + | a | + +-----+ + | abc | + +-----+""") + + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_batch_mode_csv(executor): + run(executor, """create table test(a text, b text)""") + run(executor, """insert into test (a, b) values('abc', 'de\nf'), ('ghi', 'jkl')""") + + sql = "select * from test;" + + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--csv"], input=sql) + + expected = '"a","b"\n"abc","de\nf"\n"ghi","jkl"\n' + + assert result.exit_code == 0 + assert expected in "".join(result.output) + + +def test_help_strings_end_with_periods(): + """Make sure click options have help text that end with a period.""" + for param in click_entrypoint.params: + if isinstance(param, click.core.Option): + assert hasattr(param, "help") + assert param.help.endswith(".") + + +def test_command_descriptions_end_with_periods(): + """Make sure that mycli commands' descriptions end with a period.""" + MyCli() + for _, command in SPECIAL_COMMANDS.items(): + assert command.description.endswith(".") + + +def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): + global clickoutput + clickoutput = "" + m = MyCli(myclirc=default_config_file) + + class TestOutput: + def get_size(self): + size = namedtuple("Size", "rows columns") + size.columns, size.rows = terminal_size + return size + + class TestExecute: + host = "test" + user = "test" + dbname = "test" + server_info = ServerInfo.from_version_string("unknown") + port = 0 + socket = '' + + def server_type(self): + return ["test"] + + class TestPromptSession: + output = TestOutput() + app = None + + m.prompt_session = TestPromptSession() + m.sqlexecute = TestExecute() + m.explicit_pager = explicit_pager + + def echo_via_pager(s): + assert expect_pager + global clickoutput + clickoutput += "".join(s) + + def secho(s): + assert not expect_pager + global clickoutput + clickoutput += s + "\n" + + monkeypatch.setattr(click, "echo_via_pager", echo_via_pager) + monkeypatch.setattr(click, "secho", secho) + m.output(testdata, SQLResult()) + if clickoutput.endswith("\n"): + clickoutput = clickoutput[:-1] + assert clickoutput == "\n".join(testdata) + + +def test_conditional_pager(monkeypatch): + testdata = "Lorem ipsum dolor sit amet consectetur adipiscing elit sed do".split(" ") + # User didn't set pager, output doesn't fit screen -> pager + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=True) + # User didn't set pager, output fits screen -> no pager + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=False, expect_pager=False) + # User manually configured pager, output doesn't fit screen -> pager + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=True, expect_pager=True) + # User manually configured pager, output fit screen -> pager + output(monkeypatch, terminal_size=(20, 20), testdata=testdata, explicit_pager=True, expect_pager=True) + + SPECIAL_COMMANDS["nopager"].handler() + output(monkeypatch, terminal_size=(5, 10), testdata=testdata, explicit_pager=False, expect_pager=False) + SPECIAL_COMMANDS["pager"].handler("") + + +def test_reserved_space_is_integer(monkeypatch): + """Make sure that reserved space is returned as an integer.""" + + def stub_terminal_size(): + return (5, 5) + + with monkeypatch.context() as m: + m.setattr(shutil, "get_terminal_size", stub_terminal_size) + mycli = MyCli() + assert isinstance(mycli.get_reserved_space(), int) + + +def test_list_dsn(monkeypatch): + monkeypatch.setattr(MyCli, "system_config_files", []) + monkeypatch.setattr(MyCli, "pwd_config_file", os.devnull) + runner = CliRunner() + # keep Windows from locking the file with delete=False + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as myclirc: + myclirc.write( + dedent("""\ + [alias_dsn] + test = mysql://test/test + """) + ) + myclirc.flush() + args = ["--list-dsn", "--myclirc", myclirc.name] + result = runner.invoke(click_entrypoint, args=args) + assert result.output == "test\n" + result = runner.invoke(click_entrypoint, args=args + ["--verbose"]) + assert result.output == "test : mysql://test/test\n" + + # delete=False means we should try to clean up + try: + if os.path.exists(myclirc.name): + os.remove(myclirc.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +def test_list_ssh_config(): + runner = CliRunner() + # keep Windows from locking the file with delete=False + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """) + ) + ssh_config.flush() + args = ["--list-ssh-config", "--ssh-config-path", ssh_config.name] + result = runner.invoke(click_entrypoint, args=args) + assert "test\n" in result.output + result = runner.invoke(click_entrypoint, args=args + ["--verbose"]) + assert "test : test.example.com\n" in result.output + + # delete=False means we should try to clean up + try: + if os.path.exists(ssh_config.name): + os.remove(ssh_config.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +def test_dsn(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + "main": {}, + "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = "auto" + self.my_cnf = {"client": {}, "mysqld": {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) + runner = CliRunner() + + # When a user supplies a DSN as database argument to mycli, + # use these values. + result = runner.invoke(mycli.main.click_entrypoint, args=["mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] == "dsn_passwd" + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 1 + and MockMyCli.connect_args["database"] == "dsn_database" + ) + + MockMyCli.connect_args = None + + # When a use supplies a DSN as database argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "3", + "--database", + "arg_database", + ], + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 3 + and MockMyCli.connect_args["database"] == "arg_database" + ) + + MockMyCli.config = { + "main": {}, + "alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn), + # use these values. + result = runner.invoke(click_entrypoint, args=["--dsn", "test"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "alias_dsn_user" + and MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" + and MockMyCli.connect_args["host"] == "alias_dsn_host" + and MockMyCli.connect_args["port"] == 4 + and MockMyCli.connect_args["database"] == "alias_dsn_database" + ) + + MockMyCli.config = { + "main": {}, + "alias_dsn": {"test": "mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database"}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + MockMyCli.connect_args = None + + # When a user uses a DSN from the configuration file (alias_dsn) + # and used command line arguments, use the command line arguments. + result = runner.invoke( + click_entrypoint, + args=[ + "--dsn", + "test", + "", + "--user", + "arg_user", + "--password", + "arg_password", + "--host", + "arg_host", + "--port", + "5", + "--database", + "arg_database", + ], + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "arg_user" + and MockMyCli.connect_args["passwd"] == "arg_password" + and MockMyCli.connect_args["host"] == "arg_host" + and MockMyCli.connect_args["port"] == 5 + and MockMyCli.connect_args["database"] == "arg_database" + ) + + # Use a DSN without password + result = runner.invoke(mycli.main.click_entrypoint, args=["mysql://dsn_user@dsn_host:6/dsn_database"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] is None + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 6 + and MockMyCli.connect_args["database"] == "dsn_database" + ) + + # Use a DSN with query parameters + result = runner.invoke(mycli.main.click_entrypoint, args=["mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?ssl_mode=off"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["user"] == "dsn_user" + and MockMyCli.connect_args["passwd"] == "dsn_passwd" + and MockMyCli.connect_args["host"] == "dsn_host" + and MockMyCli.connect_args["port"] == 6 + and MockMyCli.connect_args["database"] == "dsn_database" + and MockMyCli.connect_args["ssl"] is None + ) + + # When a user uses a DSN with query parameters, and also used command line + # arguments, prefer the command line arguments. + MockMyCli.connect_args = None + MockMyCli.config = { + "main": {}, + "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + + # keepalive_ticks as a query parameter + result = runner.invoke(mycli.main.click_entrypoint, args=["mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?keepalive_ticks=30"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert MockMyCli.connect_args["keepalive_ticks"] == 30 + + MockMyCli.connect_args = None + + # When a user uses a DSN with query parameters, and also used command line + # arguments, use the command line arguments. + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + 'mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database?ssl_mode=off', + '--ssl-mode=on', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'dsn_user' + assert MockMyCli.connect_args['passwd'] == 'dsn_passwd' + assert MockMyCli.connect_args['host'] == 'dsn_host' + assert MockMyCli.connect_args['port'] == 6 + assert MockMyCli.connect_args['database'] == 'dsn_database' + assert MockMyCli.connect_args['ssl']['mode'] == 'on' + + # Accept a literal DSN with the --dsn flag (not only an alias) + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--dsn', + 'mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert ( + MockMyCli.connect_args['user'] == 'dsn_user' + and MockMyCli.connect_args['passwd'] == 'dsn_passwd' + and MockMyCli.connect_args['host'] == 'dsn_host' + and MockMyCli.connect_args['port'] == 6 + and MockMyCli.connect_args['database'] == 'dsn_database' + ) + + # accept socket as a query parameter + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + f'mysql://dsn_user:dsn_passwd@{DEFAULT_HOST}/dsn_database?socket=mysql.sock', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'dsn_user' + assert MockMyCli.connect_args['passwd'] == 'dsn_passwd' + assert MockMyCli.connect_args['host'] == DEFAULT_HOST + assert MockMyCli.connect_args['database'] == 'dsn_database' + assert MockMyCli.connect_args['socket'] == 'mysql.sock' + + # accept character_set as a query parameter + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + f'mysql://dsn_user:dsn_passwd@{DEFAULT_HOST}/dsn_database?character_set=latin1', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'dsn_user' + assert MockMyCli.connect_args['passwd'] == 'dsn_passwd' + assert MockMyCli.connect_args['host'] == DEFAULT_HOST + assert MockMyCli.connect_args['database'] == 'dsn_database' + assert MockMyCli.connect_args['character_set'] == 'latin1' + + # --character_set overrides character_set as a query parameter + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + f'mysql://dsn_user:dsn_passwd@{DEFAULT_HOST}/dsn_database?character_set=latin1', + '--character-set=utf8mb3', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'dsn_user' + assert MockMyCli.connect_args['passwd'] == 'dsn_passwd' + assert MockMyCli.connect_args['host'] == DEFAULT_HOST + assert MockMyCli.connect_args['database'] == 'dsn_database' + assert MockMyCli.connect_args['character_set'] == 'utf8mb3' + + +def test_mysql_dsn_envvar(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setenv('MYSQL_DSN', 'mysql://dsn_user:dsn_passwd@dsn_host:7/dsn_database') + runner = CliRunner() + + result = runner.invoke(mycli.main.click_entrypoint) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert 'DSN environment variable is deprecated' not in result.output + assert ( + MockMyCli.connect_args['user'] == 'dsn_user' + and MockMyCli.connect_args['passwd'] == 'dsn_passwd' + and MockMyCli.connect_args['host'] == 'dsn_host' + and MockMyCli.connect_args['port'] == 7 + and MockMyCli.connect_args['database'] == 'dsn_database' + ) + + +def test_legacy_dsn_envvar_warns_and_falls_back(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setenv('DSN', 'mysql://dsn_user:dsn_passwd@dsn_host:8/dsn_database') + runner = CliRunner() + + result = runner.invoke(mycli.main.click_entrypoint) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert 'The DSN environment variable is deprecated' in result.output + assert ( + MockMyCli.connect_args['user'] == 'dsn_user' + and MockMyCli.connect_args['passwd'] == 'dsn_passwd' + and MockMyCli.connect_args['host'] == 'dsn_host' + and MockMyCli.connect_args['port'] == 8 + and MockMyCli.connect_args['database'] == 'dsn_database' + ) + + +def test_password_flag_uses_sentinel(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--user', + 'user', + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + '--password', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['passwd'] == EMPTY_PASSWORD_FLAG_SENTINEL + + +def test_password_option_uses_cleartext_value(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--user', + 'user', + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + '--password', + 'cleartext_password', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['passwd'] == 'cleartext_password' + + +def test_password_option_overrides_password_file_and_mysql_pwd(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setenv('MYSQL_PWD', 'env_password') + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as password_file: + password_file.write('file_password\n') + password_file.flush() + + try: + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--user', + 'user', + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + '--password', + 'option_password', + '--password-file', + password_file.name, + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['passwd'] == 'option_password' + finally: + os.remove(password_file.name) + + +def test_password_file_option_reads_password(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as password_file: + password_file.write('file_password\nsecond line ignored\n') + password_file.flush() + + try: + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--user', + 'user', + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + '--password-file', + password_file.name, + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['passwd'] == 'file_password' + finally: + os.remove(password_file.name) + + +def test_password_file_option_missing_file(): + runner = CliRunner() + missing_path = 'definitely_missing_password_file.txt' + + result = runner.invoke( + click_entrypoint, + args=[ + '--password-file', + missing_path, + ], + ) + + assert result.exit_code == 1 + assert f"Password file '{missing_path}' not found" in result.output + + +def test_username_option_and_mysql_user_envvar(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--username', + 'option_user', + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'option_user' + + MockMyCli.connect_args = None + monkeypatch.setenv('MYSQL_USER', 'env_user') + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--host', + DEFAULT_HOST, + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'env_user' + + +def test_host_option_and_mysql_host_envvar(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--host', + 'option_host', + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['host'] == 'option_host' + + MockMyCli.connect_args = None + monkeypatch.setenv('MYSQL_HOST', 'env_host') + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['host'] == 'env_host' + + +def test_hostname_option_alias(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--hostname', + 'alias_host', + '--port', + f'{DEFAULT_PORT}', + '--database', + 'database', + ], + ) + assert result.exit_code == 0 + assert MockMyCli.connect_args['host'] == 'alias_host' + + +def test_port_option_and_mysql_tcp_port_envvar(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--host', + DEFAULT_HOST, + '--port', + '12345', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['port'] == 12345 + + MockMyCli.connect_args = None + monkeypatch.setenv('MYSQL_TCP_PORT', '23456') + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--host', + DEFAULT_HOST, + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['port'] == 23456 + + +def test_socket_option_and_mysql_unix_socket_envvar(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': {}, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + runner = CliRunner() + + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--socket', + 'option.sock', + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['socket'] == 'option.sock' + + MockMyCli.connect_args = None + monkeypatch.setenv('MYSQL_UNIX_SOCKET', 'env.sock') + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + '--database', + 'database', + ], + ) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['socket'] == 'env.sock' + + +def test_mysql_user_envvar_overrides_dsn_resolution(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + 'main': {}, + 'alias_dsn': { + 'prod': 'mysql://alias_user:alias_password@alias_host:4/alias_database', + }, + 'connection': { + 'default_keepalive_ticks': 0, + }, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + monkeypatch.setenv('MYSQL_USER', 'env_user') + runner = CliRunner() + + result = runner.invoke(mycli.main.click_entrypoint, args=['prod']) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert MockMyCli.connect_args['user'] == 'env_user' + assert MockMyCli.connect_args['passwd'] is None + assert MockMyCli.connect_args['host'] is None + assert MockMyCli.connect_args['port'] is None + assert MockMyCli.connect_args['database'] == 'prod' + + MockMyCli.connect_args = None + result = runner.invoke(mycli.main.click_entrypoint, args=['mysql://dsn_user:dsn_passwd@dsn_host:6/dsn_database']) + assert result.exit_code == 0, result.output + ' ' + str(result.exception) + assert ( + MockMyCli.connect_args['user'] == 'env_user' + and MockMyCli.connect_args['passwd'] == 'dsn_passwd' + and MockMyCli.connect_args['host'] == 'dsn_host' + and MockMyCli.connect_args['port'] == 6 + and MockMyCli.connect_args['database'] == 'dsn_database' + ) + + +def test_ssh_config(monkeypatch): + # Setup classes to mock mycli.main.MyCli + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + config = { + "main": {}, + "alias_dsn": {}, + "connection": { + "default_keepalive_ticks": 0, + }, + } + + def __init__(self, **args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = "auto" + self.my_cnf = {"client": {}, "mysqld": {}} + self.default_keepalive_ticks = 0 + + def connect(self, **args): + MockMyCli.connect_args = args + + def run_query(self, query, new_line=True): + pass + + import mycli.main + + monkeypatch.setattr(mycli.main, "MyCli", MockMyCli) + runner = CliRunner() + + # Setup temporary configuration + # keep Windows from locking the file with delete=False + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as ssh_config: + ssh_config.write( + dedent("""\ + Host test + Hostname test.example.com + User joe + Port 22222 + IdentityFile ~/.ssh/gateway + """) + ) + ssh_config.flush() + + # When a user supplies a ssh config. + result = runner.invoke(mycli.main.click_entrypoint, args=["--ssh-config-path", ssh_config.name, "--ssh-config-host", "test"]) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "joe" + and MockMyCli.connect_args["ssh_host"] == "test.example.com" + and MockMyCli.connect_args["ssh_port"] == 22222 + and MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser("~") + "/.ssh/gateway" + ) + + # When a user supplies a ssh config host as argument to mycli, + # and used command line arguments, use the command line + # arguments. + result = runner.invoke( + mycli.main.click_entrypoint, + args=[ + "--ssh-config-path", + ssh_config.name, + "--ssh-config-host", + "test", + "--ssh-user", + "arg_user", + "--ssh-host", + "arg_host", + "--ssh-port", + "3", + "--ssh-key-filename", + "/path/to/key", + ], + ) + assert result.exit_code == 0, result.output + " " + str(result.exception) + assert ( + MockMyCli.connect_args["ssh_user"] == "arg_user" + and MockMyCli.connect_args["ssh_host"] == "arg_host" + and MockMyCli.connect_args["ssh_port"] == 3 + and MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" + ) + + # delete=False means we should try to clean up + try: + if os.path.exists(ssh_config.name): + os.remove(ssh_config.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +@dbtest +def test_init_command_arg(executor): + init_command = "set sql_select_limit=1000" + sql = 'show variables like "sql_select_limit";' + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--init-command", init_command], input=sql) + + expected = "sql_select_limit\t1000\n" + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_init_command_multiple_arg(executor): + init_command = "set sql_select_limit=2000; set max_join_size=20000" + sql = 'show variables like "sql_select_limit";\nshow variables like "max_join_size"' + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS + ["--init-command", init_command], input=sql) + + expected_sql_select_limit = "sql_select_limit\t2000\n" + expected_max_join_size = "max_join_size\t20000\n" + + assert result.exit_code == 0 + assert expected_sql_select_limit in result.output + assert expected_max_join_size in result.output + + +@dbtest +def test_global_init_commands(executor): + """Tests that global init-commands from config are executed by default.""" + # The global init-commands section in test/myclirc sets sql_select_limit=9999 + sql = 'show variables like "sql_select_limit";' + runner = CliRunner() + result = runner.invoke(click_entrypoint, args=CLI_ARGS, input=sql) + expected = "sql_select_limit\t9999\n" + assert result.exit_code == 0 + assert expected in result.output + + +@dbtest +def test_execute_with_logfile(executor): + """Test that --execute combines with --logfile""" + sql = 'select 1' + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as logfile: + result = runner.invoke(mycli.main.click_entrypoint, args=CLI_ARGS + ["--logfile", logfile.name, "--execute", sql]) + assert result.exit_code == 0 + + assert os.path.getsize(logfile.name) > 0 + + try: + if os.path.exists(logfile.name): + os.remove(logfile.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +@dbtest +def test_execute_with_short_logfile_option(executor): + """Test that --execute combines with -l""" + sql = 'select 1' + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode="w", delete=False) as logfile: + result = runner.invoke(mycli.main.click_entrypoint, args=CLI_ARGS + ["-l", logfile.name, "--execute", sql]) + assert result.exit_code == 0 + + assert os.path.getsize(logfile.name) > 0 + + try: + if os.path.exists(logfile.name): + os.remove(logfile.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +def noninteractive_mock_mycli(monkeypatch): + class Formatter: + format_name = None + + class Logger: + def debug(self, *args, **args_dict): + pass + + def error(self, *args, **args_dict): + pass + + def warning(self, *args, **args_dict): + pass + + class MockMyCli: + connect_calls = 0 + ran_queries = [] + + config = { + 'main': { + 'use_keyring': 'False', + 'my_cnf_transition_done': 'True', + }, + 'connection': {}, + } + + def __init__(self, **_args): + self.logger = Logger() + self.destructive_warning = False + self.main_formatter = Formatter() + self.redirect_formatter = Formatter() + self.ssl_mode = 'auto' + self.my_cnf = {'client': {}, 'mysqld': {}} + self.default_keepalive_ticks = 0 + self.config_without_package_defaults = {'connection': {}} + + def connect(self, **_args): + MockMyCli.connect_calls += 1 + + def run_query(self, query, checkpoint=None, new_line=True): + MockMyCli.ran_queries.append(query) + + def run_cli(self): + raise AssertionError('should not enter interactive cli') + + def close(self): + pass + + import mycli.main + import mycli.main_modes.batch + + monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) + return mycli.main, mycli.main_modes.batch, MockMyCli + + +def test_execute_arg_warns_about_ignoring_stdin(monkeypatch): + mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + # the test env should make sure stdin is not a TTY + result = runner.invoke(mycli_main.click_entrypoint, args=['--execute', 'select 1;']) + + # this exit_code is as written currently, but a debatable choice, + # since there was a warning + assert result.exit_code == 0 + assert 'Ignoring STDIN' in result.output + + +def test_verbose_and_quiet_are_incompatible() -> None: + runner = CliRunner() + + result = runner.invoke(click_entrypoint, args=['--verbose', '--quiet']) + + assert result.exit_code == 1 + assert 'incompatible.' in result.output + + +def test_quiet_sets_negative_cli_verbosity(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + + cli_args = main.CliArgs() + cli_args.quiet = True + + call_click_entrypoint_direct(cli_args) + + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.init_kwargs['cli_verbosity'] == -1 + + +def test_resume_requires_batch() -> None: + runner = CliRunner() + + result = runner.invoke(click_entrypoint, args=['--checkpoint', os.devnull, '--resume']) + + assert result.exit_code == 1 + assert 'Error:' in result.output + + +def test_resume_requires_checkpoint() -> None: + runner = CliRunner() + + result = runner.invoke(click_entrypoint, args=['--batch', os.devnull, '--resume']) + + assert result.exit_code == 1 + assert 'Error:' in result.output + + +def test_execute_arg_supersedes_batch_file(monkeypatch): + mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--execute', 'select 1;', '--batch', batch_file.name], + ) + # this exit_code is as written currently, but a debatable choice, + # since there was a warning + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 1;'] + finally: + os.remove(batch_file.name) + + +@dbtest +def test_null_string_config(monkeypatch): + monkeypatch.setattr(MyCli, 'system_config_files', []) + monkeypatch.setattr(MyCli, 'pwd_config_file', os.devnull) + runner = CliRunner() + # keep Windows from locking the file with delete=False + with NamedTemporaryFile(mode='w', delete=False) as myclirc: + myclirc.write( + dedent("""\ + [main] + null_string = + """) + ) + myclirc.flush() + args = CLI_ARGS_WITHOUT_DB + ['--myclirc', myclirc.name, '--format=table', '--execute', 'SELECT NULL'] + result = runner.invoke(mycli.main.click_entrypoint, args=args) + assert '' in result.output + assert '' not in result.output + + # delete=False means we should try to clean up + try: + if os.path.exists(myclirc.name): + os.remove(myclirc.name) + except Exception as e: + print(f'An error occurred while attempting to delete the file: {e}') + + +def test_change_prompt_format_requires_argument() -> None: + cli = make_bare_mycli() + assert main.MyCli.change_prompt_format(cli, '')[0].status == 'Missing required argument, format.' + + +def test_change_prompt_format_updates_prompt() -> None: + cli = make_bare_mycli() + assert main.MyCli.change_prompt_format(cli, '\\u@\\h> ')[0].status == 'Changed prompt format to \\u@\\h> ' + + +def test_output_timing_logs_and_prints_with_warning_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + timings_logged: list[str] = [] + cli.log_output = lambda text: timings_logged.append(text) # type: ignore[assignment] + printed: list[tuple[Any, Any]] = [] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) + main.MyCli.output_timing(cli, 'Time: 1.000s', is_warnings_style=True) + assert timings_logged == ['Time: 1.000s'] + assert printed[-1][1] == cli.ptoolkit_style + + +def test_run_cli_delegates_to_main_repl(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + run_cli_calls: list[Any] = [] + monkeypatch.setattr(main, 'main_repl', lambda target: run_cli_calls.append(target)) + main.MyCli.run_cli(cli) + assert run_cli_calls == [cli] + + +def test_get_output_margin_uses_prompt_session_render_counter(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + render_counters: list[int] = [] + cli.prompt_lines = 0 + cli.get_reserved_space = lambda: 2 # type: ignore[assignment] + cli.prompt_session = cast( + Any, + SimpleNamespace(app=SimpleNamespace(render_counter=7)), + ) + + def fake_render_prompt_string(mycli: Any, string: str, render_counter: int) -> FormattedText: + render_counters.append(render_counter) + return to_formatted_text('line1\nline2') + + monkeypatch.setattr(repl_mode, 'render_prompt_string', fake_render_prompt_string) + monkeypatch.setattr(main.special, 'is_timing_enabled', lambda: False) + assert main.MyCli.get_output_margin(cli, 'ok') == 5 + assert render_counters == [7] + + +def test_on_completions_refreshed_updates_completer_and_invalidates_prompt() -> None: + cli = make_bare_mycli() + entered_lock = {'count': 0} + invalidated: list[bool] = [] + cli._completer_lock = cast(Any, ReusableLock(lambda: entered_lock.__setitem__('count', entered_lock['count'] + 1))) + cli.prompt_session = cast(Any, SimpleNamespace(app=SimpleNamespace(invalidate=lambda: invalidated.append(True)))) + cli.completer = cast(Any, SimpleNamespace(dbmetadata={})) + copy_calls: list[tuple[Any, str | None]] = [] + new_completer = cast( + Any, + SimpleNamespace( + dbname='current', + get_completions=lambda document, event: ['done'], + copy_other_schemas_from=lambda source, exclude: copy_calls.append((source, exclude)), + ), + ) + main.MyCli._on_completions_refreshed(cli, new_completer) + assert cli.completer is new_completer + assert invalidated == [True] + assert entered_lock['count'] == 1 + assert copy_calls == [(copy_calls[0][0], 'current')] + + +def test_click_entrypoint_callback_covers_dsn_list_init_commands(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {'prod': 'mysql://u:p@h/db'}, + 'alias_dsn.init-commands': {'prod': ['set a=1', 'set b=2']}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + + cli_args = main.CliArgs() + cli_args.dsn = 'prod' + cli_args.init_command = 'set c=3' + call_click_entrypoint_direct(cli_args) + + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['init_command'] == 'set a=1; set b=2; set c=3' + + +def test_click_entrypoint_callback_uses_batch_with_progress_path(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + monkeypatch.setattr(main, 'main_batch_with_progress_bar', lambda mycli, cli_args: 12) + + cli_args = main.CliArgs() + cli_args.batch = 'queries.sql' + cli_args.progress = True + with pytest.raises(SystemExit) as excinfo: + call_click_entrypoint_direct(cli_args) + assert excinfo.value.code == 12 + + +def test_click_entrypoint_callback_uses_batch_without_progress_path(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + monkeypatch.setattr(main, 'main_batch_without_progress_bar', lambda mycli, cli_args: 13) + + cli_args = main.CliArgs() + cli_args.batch = 'queries.sql' + cli_args.progress = False + with pytest.raises(SystemExit) as excinfo: + call_click_entrypoint_direct(cli_args) + assert excinfo.value.code == 13 + + +def test_click_entrypoint_callback_covers_mycnf_underscore_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + click_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'false'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + }, + my_cnf={'client': {'ssl_ca': '/tmp/ca.pem'}, 'mysqld': {}}, + config_without_package_defaults={'main': {}}, + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + + call_click_entrypoint_direct(main.CliArgs()) + assert any('ssl-ca = /tmp/ca.pem' in line for line in click_lines) + + +def test_format_sqlresult_uses_redirect_formatter_when_redirected() -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + + result = SQLResult(header=['id'], rows=[(1,)], status='ok') + assert list(main.MyCli.format_sqlresult(cli, result, is_redirected=True)) == ['plain output'] + + assert cli.main_formatter.calls == [] + assert len(cli.redirect_formatter.calls) == 1 + + +def test_format_sqlresult_materializes_cursor_rows_when_width_is_limited(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + rows = FakeCursorBase(rows=[(1,)], rowcount=1, description=[('id', 3)]) + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) + + result = SQLResult(header=['id'], rows=cast(Any, rows), status='ok') + list(main.MyCli.format_sqlresult(cli, result, max_width=100)) + + formatted_rows = cli.main_formatter.calls[-1][0][0] + assert formatted_rows == [(1,)] + + +def test_format_sqlresult_appends_postamble() -> None: + cli = make_bare_mycli() + result = SQLResult(header=['id'], rows=[(1,)], status='ok', postamble='done') + + assert list(main.MyCli.format_sqlresult(cli, result))[-1] == 'done' + + +def test_get_last_query_returns_latest_query() -> None: + cli = make_bare_mycli() + cli.query_history = [main.Query('select 1', True, False)] + + assert main.MyCli.get_last_query(cli) == 'select 1' + + +def test_connect_reports_expired_password_login_error(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.config = {'connection': {}, 'main': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class ExpiredPasswordSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [pymysql.OperationalError(main.ER_MUST_CHANGE_PASSWORD_LOGIN, 'must change password')] + + monkeypatch.setattr(main, 'SQLExecute', ExpiredPasswordSQLExecute) + + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db', port=3307) + + assert any('password has expired' in message for message in echo_calls) + + +def test_connect_sets_cli_sandbox_mode_when_sqlexecute_enters_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.config = {'connection': {}, 'main': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class SandboxSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.sandbox_mode = True + + monkeypatch.setattr(main, 'SQLExecute', SandboxSQLExecute) + + main.MyCli.connect(cli, host='db', port=3307) + + assert cli.sandbox_mode is True + assert any('password has expired' in message for message in echo_calls) diff --git a/test/pytests/test_main_modes_batch.py b/test/pytests/test_main_modes_batch.py new file mode 100644 index 00000000..9d7fd9a2 --- /dev/null +++ b/test/pytests/test_main_modes_batch.py @@ -0,0 +1,733 @@ +from __future__ import annotations + +from dataclasses import dataclass +from io import TextIOWrapper +import os +from pathlib import Path +import sys +from tempfile import NamedTemporaryFile +from types import SimpleNamespace +from typing import Any, Literal, cast + +from click.testing import CliRunner +import pytest + +import mycli.main_modes.batch as batch_mode +import test.pytests.test_main as test_main_module +import test.utils as test_utils + +noninteractive_mock_mycli = cast(Any, test_main_module).noninteractive_mock_mycli +TEMPFILE_PREFIX = cast(str, cast(Any, test_utils).TEMPFILE_PREFIX) + + +@dataclass +class DummyCliArgs: + format: str = 'tsv' + noninteractive: bool = True + throttle: float = 0.0 + checkpoint: str | TextIOWrapper | None = None + batch: str | None = None + resume: bool = False + + +@dataclass +class DummyFormatter: + format_name: str | None = None + + +class DummyLogger: + def __init__(self) -> None: + self.warning_messages: list[str] = [] + + def warning(self, message: str) -> None: + self.warning_messages.append(message) + + +class DummyMyCli: + def __init__(self, destructive_warning: bool = False, run_query_error: Exception | None = None) -> None: + self.main_formatter = DummyFormatter() + self.destructive_warning = destructive_warning + self.destructive_keywords = ('drop',) + self.logger = DummyLogger() + self.run_query_error = run_query_error + self.ran_queries: list[tuple[str, str | TextIOWrapper | None, bool]] = [] + + def run_query(self, query: str, checkpoint: str | TextIOWrapper | None = None, new_line: bool = True) -> None: + if self.run_query_error is not None: + raise self.run_query_error + self.ran_queries.append((query, checkpoint, new_line)) + + +class DummyFile: + def __init__(self, name: str) -> None: + self.name = name + self.closed = False + + def close(self) -> None: + self.closed = True + + +class DummyProgressBar: + calls: list[list[int]] = [] + + def __init__(self, *args, **kwargs) -> None: + pass + + def __enter__(self) -> 'DummyProgressBar': + return self + + def __exit__(self, exc_type, exc, tb) -> Literal[False]: + return False + + def __call__(self, iterable) -> list[int]: + values = list(iterable) + DummyProgressBar.calls.append(values) + return values + + +def dispatch_batch_statements( + mycli: DummyMyCli, + cli_args: DummyCliArgs, + statements: str, + batch_counter: int, +) -> None: + batch_mode.dispatch_batch_statements(cast(Any, mycli), cast(Any, cli_args), statements, batch_counter) + + +def main_batch_with_progress_bar(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: + return batch_mode.main_batch_with_progress_bar(cast(Any, mycli), cast(Any, cli_args)) + + +def main_batch_without_progress_bar(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: + return batch_mode.main_batch_without_progress_bar(cast(Any, mycli), cast(Any, cli_args)) + + +def main_batch_from_stdin(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: + return batch_mode.main_batch_from_stdin(cast(Any, mycli), cast(Any, cli_args)) + + +def make_fake_sys(stdin_tty: bool, stderr_tty: bool | None = None) -> SimpleNamespace: + stderr = SimpleNamespace(isatty=lambda: stderr_tty) if stderr_tty is not None else object() + return SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: stdin_tty), + stderr=stderr, + exit=sys.exit, + ) + + +def patch_progress_mode(monkeypatch, mycli_main, mycli_main_batch) -> None: + DummyProgressBar.calls.clear() + monkeypatch.setattr(mycli_main_batch, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(mycli_main_batch.prompt_toolkit.output, 'create_output', lambda **kwargs: object()) + fake_sys = make_fake_sys(stdin_tty=False, stderr_tty=True) + monkeypatch.setattr(mycli_main, 'sys', fake_sys) + monkeypatch.setattr(mycli_main_batch, 'sys', fake_sys) + + +def invoke_click_batch( + runner: CliRunner, + mycli_main, + contents: str, + args: list[str] | None = None, +): + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write(contents) + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name] + (args or []), + ) + return result, batch_file.name + finally: + if os.path.exists(batch_file.name): + os.remove(batch_file.name) + + +def write_batch_file(tmp_path: Path, contents: str) -> str: + batch_path = tmp_path / 'batch.sql' + batch_path.write_text(contents, encoding='utf-8') + return str(batch_path) + + +def open_checkpoint_file(tmp_path: Path, contents: str) -> TextIOWrapper: + checkpoint_path = tmp_path / 'checkpoint.sql' + checkpoint_path.write_text(contents, encoding='utf-8') + return checkpoint_path.open('a', encoding='utf-8') + + +def test_replay_checkpoint_file_returns_zero_without_replayable_batch(tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + assert batch_mode.replay_checkpoint_file(batch_path, None, resume=True) == 0 + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='incompatible with reading from the standard input'): + batch_mode.replay_checkpoint_file('-', checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_checkpoint_longer_than_batch(tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='Checkpoint script longer than batch script.'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_batch_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad batch'))) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_batch_iteration_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def raise_on_next(): + raise ValueError('bad batch iterator') + yield + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return raise_on_next() + return iter([('select 1;', 0)]) + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --batch file: {batch_path}: bad batch iterator'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_checkpoint_read_error(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + def fake_statements_from_filehandle(handle): + if handle.name == batch_path: + return iter([('select 1;', 0)]) + return (_ for _ in ()).throw(ValueError('bad checkpoint')) + + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match=f'Error reading --checkpoint file: {checkpoint.name}: bad checkpoint'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_missing_files(tmp_path: Path) -> None: + batch_path = str(tmp_path / 'missing.sql') + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='FileNotFoundError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +def test_replay_checkpoint_file_rejects_open_errors(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + + monkeypatch.setattr(batch_mode.click, 'open_file', lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError('open failed'))) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + with pytest.raises(batch_mode.CheckpointReplayError, match='OSError'): + batch_mode.replay_checkpoint_file(batch_path, checkpoint, resume=True) + + +@pytest.mark.parametrize( + ('format_name', 'batch_counter', 'expected'), + ( + ('csv', 1, 'csv-noheader'), + ('tsv', 1, 'tsv_noheader'), + ('table', 1, 'ascii'), + ('vertical', 1, 'tsv'), + ('csv', 0, 'csv'), + ('tsv', 0, 'tsv'), + ('table', 0, 'ascii'), + ('vertical', 0, 'tsv'), + ), +) +def test_dispatch_batch_statements_sets_expected_output_format( + format_name: str, + batch_counter: int, + expected: str, +) -> None: + mycli = DummyMyCli() + cli_args = DummyCliArgs(format=format_name, checkpoint='cp') + + dispatch_batch_statements(mycli, cli_args, 'select 1;', batch_counter) + + assert mycli.main_formatter.format_name == expected + assert mycli.ran_queries == [('select 1;', 'cp', True)] + + +def test_dispatch_batch_statements_confirms_destructive_queries_before_running(monkeypatch) -> None: + mycli = DummyMyCli(destructive_warning=True) + cli_args = DummyCliArgs(noninteractive=False) + opened_tty = object() + + monkeypatch.setattr(batch_mode, 'is_destructive', lambda _keywords, _statement: True) + monkeypatch.setattr(batch_mode, 'confirm_destructive_query', lambda _keywords, _statement: True) + monkeypatch.setattr(batch_mode, 'open', lambda _path: opened_tty, raising=False) + monkeypatch.setattr(batch_mode, 'sys', SimpleNamespace(stdin=None)) + + dispatch_batch_statements(mycli, cli_args, 'drop table demo;', 0) + + assert batch_mode.sys.stdin is opened_tty + assert mycli.ran_queries == [('drop table demo;', None, True)] + + +def test_dispatch_batch_statements_skips_query_when_destructive_confirmation_is_rejected(monkeypatch) -> None: + mycli = DummyMyCli(destructive_warning=True) + cli_args = DummyCliArgs(noninteractive=False) + + monkeypatch.setattr(batch_mode, 'is_destructive', lambda _keywords, _statement: True) + monkeypatch.setattr(batch_mode, 'confirm_destructive_query', lambda _keywords, _statement: False) + monkeypatch.setattr(batch_mode, 'open', lambda _path: object(), raising=False) + monkeypatch.setattr(batch_mode, 'sys', SimpleNamespace(stdin=None)) + + dispatch_batch_statements(mycli, cli_args, 'drop table demo;', 0) + + assert mycli.ran_queries == [] + + +def test_dispatch_batch_statements_raises_when_tty_cannot_be_opened(monkeypatch) -> None: + mycli = DummyMyCli(destructive_warning=True) + cli_args = DummyCliArgs(noninteractive=False) + + monkeypatch.setattr(batch_mode, 'is_destructive', lambda _keywords, _statement: True) + monkeypatch.setattr(batch_mode, 'open', lambda _path: (_ for _ in ()).throw(OSError('tty unavailable')), raising=False) + + with pytest.raises(OSError, match='tty unavailable'): + dispatch_batch_statements(mycli, cli_args, 'drop table demo;', 0) + + assert mycli.logger.warning_messages == ['Unable to open TTY as stdin.'] + + +def test_dispatch_batch_statements_sleeps_and_reraises_query_errors(monkeypatch) -> None: + mycli = DummyMyCli(run_query_error=RuntimeError('boom')) + cli_args = DummyCliArgs(throttle=0.25) + sleep_calls: list[float] = [] + secho_calls: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(batch_mode.time, 'sleep', lambda seconds: sleep_calls.append(seconds)) + monkeypatch.setattr( + batch_mode.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + with pytest.raises(RuntimeError, match='boom'): + dispatch_batch_statements(mycli, cli_args, 'select 1;', 1) + + assert sleep_calls == [0.25] + assert secho_calls == [] + + +def test_main_batch_with_progress_bar_returns_error_when_batch_is_missing() -> None: + assert main_batch_with_progress_bar(DummyMyCli(), DummyCliArgs()) == 1 + + +def test_main_batch_with_progress_bar_rejects_non_files(monkeypatch, tmp_path) -> None: + messages: list[tuple[str, bool, str]] = [] + cli_args = DummyCliArgs(batch=str(tmp_path)) + + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('--progress is only compatible with a plain file.', True, 'red')] + + +def test_main_batch_with_progress_bar_handles_open_errors(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + cli_args = DummyCliArgs(batch='missing.sql') + + monkeypatch.setattr(batch_mode.os.path, 'exists', lambda _path: False) + monkeypatch.setattr(batch_mode.click, 'open_file', lambda _path: (_ for _ in ()).throw(FileNotFoundError())) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('Failed to open --batch file: missing.sql', True, 'red')] + + +def test_main_batch_with_progress_bar_handles_counting_value_errors(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + count_handle = DummyFile('count') + cli_args = DummyCliArgs(batch='statements.sql') + + monkeypatch.setattr(batch_mode.os.path, 'exists', lambda _path: False) + monkeypatch.setattr(batch_mode.click, 'open_file', lambda _path: count_handle) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad sql'))) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('Error reading --batch file: statements.sql: bad sql', True, 'red')] + + +def test_main_batch_with_progress_bar_processes_all_statements(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + count_handle = DummyFile('count') + run_handle = DummyFile('run') + open_calls: list[str] = [] + dispatch_calls: list[tuple[str, int]] = [] + cli_args = DummyCliArgs(batch='statements.sql') + + def fake_open_file(path: str) -> DummyFile: + open_calls.append(path) + return count_handle if len(open_calls) == 1 else run_handle + + def fake_statements_from_filehandle(handle: DummyFile): + if handle is count_handle: + return iter([('select 1;', 0), ('select 2;', 1)]) + return iter([('select 1;', 0), ('select 2;', 1)]) + + DummyProgressBar.calls.clear() + monkeypatch.setattr(batch_mode.os.path, 'exists', lambda _path: False) + monkeypatch.setattr(batch_mode.click, 'open_file', fake_open_file) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(batch_mode.prompt_toolkit.output, 'create_output', lambda **_kwargs: object()) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=False)) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert messages == [('Ignoring STDIN since --batch was also given.', True, 'yellow')] + assert dispatch_calls == [('select 1;', 0), ('select 2;', 1)] + assert DummyProgressBar.calls == [[0, 1]] + assert count_handle.closed is True + assert run_handle.closed is True + + +def test_main_batch_with_progress_bar_returns_error_when_dispatch_fails(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + count_handle = DummyFile('count') + run_handle = DummyFile('run') + open_calls = 0 + cli_args = DummyCliArgs(batch='statements.sql') + + def fake_open_file(_path: str) -> DummyFile: + nonlocal open_calls + open_calls += 1 + return count_handle if open_calls == 1 else run_handle + + def fake_statements_from_filehandle(handle: DummyFile): + if handle is count_handle: + return iter([('select 1;', 0)]) + return iter([('select 1;', 0)]) + + monkeypatch.setattr(batch_mode.os.path, 'exists', lambda _path: False) + monkeypatch.setattr(batch_mode.click, 'open_file', fake_open_file) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', fake_statements_from_filehandle) + monkeypatch.setattr(batch_mode, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(batch_mode.prompt_toolkit.output, 'create_output', lambda **_kwargs: object()) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, _statement, _counter: (_ for _ in ()).throw(OSError('dispatch failed')), + ) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('dispatch failed', True, 'red')] + assert run_handle.closed is True + + +def test_main_batch_without_progress_bar_returns_error_when_batch_is_missing() -> None: + assert main_batch_without_progress_bar(DummyMyCli(), DummyCliArgs()) == 1 + + +def test_main_batch_without_progress_bar_handles_open_errors(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + cli_args = DummyCliArgs(batch='missing.sql') + + monkeypatch.setattr(batch_mode.click, 'open_file', lambda _path: (_ for _ in ()).throw(FileNotFoundError())) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('Failed to open --batch file: missing.sql', True, 'red')] + + +def test_main_batch_without_progress_bar_processes_statements(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + batch_handle = DummyFile('run') + dispatch_calls: list[tuple[str, int]] = [] + cli_args = DummyCliArgs(batch='statements.sql') + + monkeypatch.setattr(batch_mode.click, 'open_file', lambda _path: batch_handle) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: iter([('select 1;', 0), ('select 2;', 1)])) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=False)) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert messages == [('Ignoring STDIN since --batch was also given.', True, 'red')] + assert dispatch_calls == [('select 1;', 0), ('select 2;', 1)] + assert batch_handle.closed is True + + +def test_main_batch_without_progress_bar_skips_checkpoint_prefix(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\nselect 3;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 3;', 2)] + + +def test_main_batch_without_progress_bar_skips_only_matching_duplicate_prefix(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 1;', 1), ('select 2;', 2)] + + +def test_main_batch_without_progress_bar_fails_on_mismatched_checkpoint(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert dispatch_calls == [] + + +def test_main_batch_without_progress_bar_succeeds_when_checkpoint_skips_all(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\n') + dispatch_calls: list[tuple[str, int]] = [] + + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\nselect 2;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [] + + +def test_main_batch_with_progress_bar_skips_checkpoint_prefix_and_counts_all_statements(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\nselect 2;\nselect 3;\n') + dispatch_calls: list[tuple[str, int]] = [] + + DummyProgressBar.calls.clear() + monkeypatch.setattr(batch_mode, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(batch_mode.prompt_toolkit.output, 'create_output', lambda **_kwargs: object()) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 1;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 0 + assert dispatch_calls == [('select 2;', 1), ('select 3;', 2)] + assert DummyProgressBar.calls == [[0, 1, 2]] + + +def test_main_batch_with_progress_bar_returns_error_when_checkpoint_replay_fails(monkeypatch, tmp_path: Path) -> None: + batch_path = write_batch_file(tmp_path, 'select 1;\n') + messages: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + with open_checkpoint_file(tmp_path, 'select 9;\n') as checkpoint: + cli_args = DummyCliArgs(batch=batch_path, checkpoint=checkpoint, resume=True) + + result = main_batch_with_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [(f'Error replaying --checkpoint file: {checkpoint.name}: Statement mismatch: select 9;.', True, 'red')] + + +def test_main_batch_without_progress_bar_returns_error_when_iteration_fails(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + batch_handle = DummyFile('run') + cli_args = DummyCliArgs(batch='statements.sql') + + monkeypatch.setattr(batch_mode.click, 'open_file', lambda _path: batch_handle) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad sql'))) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + monkeypatch.setattr(batch_mode, 'sys', make_fake_sys(stdin_tty=True)) + + result = main_batch_without_progress_bar(DummyMyCli(), cli_args) + + assert result == 1 + assert messages == [('bad sql', True, 'red')] + + +def test_main_batch_from_stdin_processes_statements(monkeypatch) -> None: + dispatch_calls: list[tuple[str, int]] = [] + batch_handle = object() + + monkeypatch.setattr(batch_mode.click, 'get_text_stream', lambda _name: batch_handle) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: iter([('select 1;', 0), ('select 2;', 1)])) + monkeypatch.setattr( + batch_mode, + 'dispatch_batch_statements', + lambda _mycli, _cli_args, statement, counter: dispatch_calls.append((statement, counter)), + ) + + result = main_batch_from_stdin(DummyMyCli(), DummyCliArgs()) + + assert result == 0 + assert dispatch_calls == [('select 1;', 0), ('select 2;', 1)] + + +def test_main_batch_from_stdin_returns_error_for_value_errors(monkeypatch) -> None: + messages: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(batch_mode.click, 'get_text_stream', lambda _name: object()) + monkeypatch.setattr(batch_mode, 'statements_from_filehandle', lambda _handle: (_ for _ in ()).throw(ValueError('bad stdin'))) + monkeypatch.setattr(batch_mode.click, 'secho', lambda message, err, fg: messages.append((message, err, fg))) + + result = main_batch_from_stdin(DummyMyCli(), DummyCliArgs()) + + assert result == 1 + assert messages == [('bad stdin', True, 'red')] + + +@pytest.mark.parametrize( + ('contents', 'extra_args', 'expected_queries', 'expected_progress'), + ( + ('select 2;', [], ['select 2;'], None), + ('select 2; select 3;\nselect 4;\n', [], ['select 2;', 'select 3;', 'select 4;'], None), + ('select 2;\nselect 2;\nselect 2;\n', ['--progress'], ['select 2;', 'select 2;', 'select 2;'], [[0, 1, 2]]), + ('select 2; select 3;\nselect 4;\n', ['--progress'], ['select 2;', 'select 3;', 'select 4;'], [[0, 1, 2]]), + ), +) +def test_click_batch_file_modes(monkeypatch, contents: str, extra_args: list[str], expected_queries: list[str], expected_progress) -> None: + mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + MockMyCli.ran_queries = [] + + if '--progress' in extra_args: + patch_progress_mode(monkeypatch, mycli_main, mycli_main_batch) + + result, _batch_file_name = invoke_click_batch(runner, mycli_main, contents, extra_args) + + assert result.exit_code == 0 + assert MockMyCli.ran_queries == expected_queries + if expected_progress is not None: + assert DummyProgressBar.calls == expected_progress + + +def test_click_batch_file_skips_checkpoint_prefix(monkeypatch, tmp_path: Path) -> None: + mycli_main, _mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + MockMyCli.ran_queries = [] + checkpoint_path = tmp_path / 'checkpoint.sql' + checkpoint_path.write_text('select 2;\n', encoding='utf-8') + + result, _batch_file_name = invoke_click_batch( + runner, + mycli_main, + 'select 2;\nselect 3;\n', + [f'--checkpoint={checkpoint_path}', '--resume'], + ) + + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 3;'] + + +def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path) -> None: + mycli_main, mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + patch_progress_mode(monkeypatch, mycli_main, mycli_main_batch) + + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', str(tmp_path), '--progress'], + ) + + assert result.exit_code != 0 + assert '--progress is only compatible with a plain file.' in result.output + assert MockMyCli.ran_queries == [] + + +def test_batch_file_open_error(monkeypatch) -> None: + mycli_main, _mycli_main_batch, MockMyCli = noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + result = runner.invoke(mycli_main.click_entrypoint, args=['--batch', 'definitely_missing_file.sql']) + + assert result.exit_code != 0 + assert 'Failed to open --batch file' in result.output + assert MockMyCli.ran_queries == [] diff --git a/test/pytests/test_main_modes_execute.py b/test/pytests/test_main_modes_execute.py new file mode 100644 index 00000000..2b36fe31 --- /dev/null +++ b/test/pytests/test_main_modes_execute.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +import mycli.main_modes.execute as execute_mode + + +@dataclass +class DummyCliArgs: + execute: str | None + format: str = 'tsv' + batch: str | None = None + checkpoint: str | None = None + + +@dataclass +class DummyFormatter: + format_name: str | None = None + + +class DummyMyCli: + def __init__(self, run_query_error: Exception | None = None) -> None: + self.main_formatter = DummyFormatter() + self.run_query_error = run_query_error + self.ran_queries: list[tuple[str, str | None]] = [] + + def run_query(self, query: str, checkpoint: str | None = None) -> None: + if self.run_query_error is not None: + raise self.run_query_error + self.ran_queries.append((query, checkpoint)) + + +def main_execute_from_cli(mycli: DummyMyCli, cli_args: DummyCliArgs) -> int: + return execute_mode.main_execute_from_cli(cast(Any, mycli), cast(Any, cli_args)) + + +def fake_sys(stdin_tty: bool) -> SimpleNamespace: + return SimpleNamespace(stdin=SimpleNamespace(isatty=lambda: stdin_tty)) + + +def test_main_execute_from_cli_returns_error_when_execute_is_missing() -> None: + assert main_execute_from_cli(DummyMyCli(), DummyCliArgs(execute=None)) == 1 + + +@pytest.mark.parametrize( + ('format_name', 'original_sql', 'expected_format', 'expected_sql'), + ( + ('csv', r'select 1\G', 'csv', 'select 1'), + ('tsv', r'select 2\G', 'tsv', 'select 2'), + ('table', r'select 3\G', 'ascii', 'select 3'), + ('vertical', r'select 4\G', 'tsv', r'select 4\G'), + ), +) +def test_main_execute_from_cli_sets_format_and_runs_query( + monkeypatch, + format_name: str, + original_sql: str, + expected_format: str, + expected_sql: str, +) -> None: + secho_calls: list[tuple[str, bool, str]] = [] + mycli = DummyMyCli() + cli_args = DummyCliArgs( + execute=original_sql, + format=format_name, + batch='batch.sql', + checkpoint='cp', + ) + + monkeypatch.setattr(execute_mode, 'sys', fake_sys(stdin_tty=False)) + monkeypatch.setattr( + execute_mode.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + result = main_execute_from_cli(mycli, cli_args) + + assert result == 0 + assert mycli.main_formatter.format_name == expected_format + assert mycli.ran_queries == [(expected_sql, 'cp')] + assert secho_calls == [ + ('Ignoring STDIN since --execute was also given.', True, 'red'), + ('Ignoring --batch since --execute was also given.', True, 'red'), + ] + + +def test_main_execute_from_cli_does_not_warn_when_stdin_is_tty_and_batch_is_unset(monkeypatch) -> None: + secho_calls: list[tuple[str, bool, str]] = [] + mycli = DummyMyCli() + + monkeypatch.setattr(execute_mode, 'sys', fake_sys(stdin_tty=True)) + monkeypatch.setattr( + execute_mode.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + result = main_execute_from_cli(mycli, DummyCliArgs(execute='select 1', format='csv')) + + assert result == 0 + assert mycli.main_formatter.format_name == 'csv' + assert mycli.ran_queries == [('select 1', None)] + assert secho_calls == [] + + +def test_main_execute_from_cli_reports_query_errors(monkeypatch) -> None: + secho_calls: list[tuple[str, bool, str]] = [] + mycli = DummyMyCli(run_query_error=RuntimeError('boom')) + + monkeypatch.setattr(execute_mode, 'sys', fake_sys(stdin_tty=True)) + monkeypatch.setattr( + execute_mode.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + result = main_execute_from_cli(mycli, DummyCliArgs(execute='select 1', format='table')) + + assert result == 1 + assert mycli.main_formatter.format_name == 'ascii' + assert mycli.ran_queries == [] + assert secho_calls == [('boom', True, 'red')] diff --git a/test/pytests/test_main_modes_list_dsn.py b/test/pytests/test_main_modes_list_dsn.py new file mode 100644 index 00000000..359a4b93 --- /dev/null +++ b/test/pytests/test_main_modes_list_dsn.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import mycli.main_modes.list_dsn as list_dsn_mode + + +@dataclass +class DummyCliArgs: + verbose: int = 0 + + +class DummyConfig: + def __init__(self, value: dict[str, str] | Exception) -> None: + self.value = value + + def __getitem__(self, key: str) -> dict[str, str]: + assert key == 'alias_dsn' + if isinstance(self.value, Exception): + raise self.value + return self.value + + +class DummyMyCli: + def __init__(self, config: Any) -> None: + self.config = config + self.verbosity = 0 + + +def main_list_dsn(mycli: DummyMyCli) -> int: + return list_dsn_mode.main_list_dsn(cast(Any, mycli)) + + +def test_main_list_dsn_lists_aliases_without_values(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + mycli = DummyMyCli(DummyConfig({'prod': 'mysql://u:p@h/db', 'staging': 'mysql://u2:p2@h2/db2'})) + + monkeypatch.setattr( + list_dsn_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_dsn(mycli) + + assert result == 0 + assert secho_calls == [ + ('prod', None, None), + ('staging', None, None), + ] + + +def test_main_list_dsn_lists_aliases_with_values_in_verbose_mode(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + mycli = DummyMyCli(DummyConfig({'prod': 'mysql://u:p@h/db'})) + mycli.verbosity = 1 + + monkeypatch.setattr( + list_dsn_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_dsn(mycli) + + assert result == 0 + assert secho_calls == [('prod : mysql://u:p@h/db', None, None)] + + +def test_main_list_dsn_reports_invalid_alias_section(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + mycli = DummyMyCli(DummyConfig(KeyError('alias_dsn'))) + + monkeypatch.setattr( + list_dsn_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_dsn(mycli) + + assert result == 1 + assert secho_calls == [ + ( + 'Invalid DSNs found in the config file. Please check the "[alias_dsn]" section in myclirc.', + True, + 'red', + ) + ] + + +def test_main_list_dsn_reports_other_config_errors(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + mycli = DummyMyCli(DummyConfig(RuntimeError('boom'))) + + monkeypatch.setattr( + list_dsn_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_dsn(mycli) + + assert result == 1 + assert secho_calls == [('boom', True, 'red')] diff --git a/test/pytests/test_main_modes_list_ssh_config.py b/test/pytests/test_main_modes_list_ssh_config.py new file mode 100644 index 00000000..9ff104a4 --- /dev/null +++ b/test/pytests/test_main_modes_list_ssh_config.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, cast + +import mycli.main_modes.list_ssh_config as list_ssh_config_mode + + +@dataclass +class DummyCliArgs: + ssh_config_path: str = 'ssh_config' + verbose: int = 0 + + +class DummyMyCli: + def __init__(self, config: Any) -> None: + self.config = config + self.verbosity = 0 + + +class DummySSHConfig: + def __init__(self, hostnames: list[str] | Exception, lookups: dict[str, dict[str, str]] | None = None) -> None: + self.hostnames = hostnames + self.lookups = lookups or {} + + def get_hostnames(self) -> list[str]: + if isinstance(self.hostnames, Exception): + raise self.hostnames + return self.hostnames + + def lookup(self, hostname: str) -> dict[str, str]: + return self.lookups[hostname] + + +def main_list_ssh_config(cli_args: DummyCliArgs) -> int: + mycli = DummyMyCli(config={}) + mycli.verbosity = cli_args.verbose + return list_ssh_config_mode.main_list_ssh_config(cast(Any, mycli), cast(Any, cli_args)) + + +def test_main_list_ssh_config_lists_hostnames(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig(['prod', 'staging']) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs(verbose=0)) + + assert result == 0 + assert secho_calls == [ + ('prod', None, None), + ('staging', None, None), + ] + + +def test_main_list_ssh_config_lists_verbose_host_details(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig( + ['prod'], + lookups={'prod': {'hostname': 'db.example.com'}}, + ) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs(verbose=1)) + + assert result == 0 + assert secho_calls == [('prod : db.example.com', None, None)] + + +def test_main_list_ssh_config_reports_host_lookup_errors(monkeypatch) -> None: + secho_calls: list[tuple[str, bool | None, str | None]] = [] + ssh_config = DummySSHConfig(KeyError('bad ssh config')) + + monkeypatch.setattr(list_ssh_config_mode, 'read_ssh_config', lambda _path: ssh_config) + monkeypatch.setattr( + list_ssh_config_mode.click, + 'secho', + lambda message, err=None, fg=None: secho_calls.append((message, err, fg)), + ) + + result = main_list_ssh_config(DummyCliArgs()) + + assert result == 1 + assert secho_calls == [('Error reading ssh config', True, 'red')] diff --git a/test/pytests/test_main_modes_repl.py b/test/pytests/test_main_modes_repl.py new file mode 100644 index 00000000..d7efc544 --- /dev/null +++ b/test/pytests/test_main_modes_repl.py @@ -0,0 +1,1312 @@ +from __future__ import annotations + +import builtins +from collections.abc import Generator, Iterator +from dataclasses import dataclass +from io import StringIO +import os +from types import SimpleNamespace +from typing import Any, Literal, cast + +from prompt_toolkit.formatted_text import to_formatted_text, to_plain_text +import pymysql +import pytest + +import mycli.main_modes.repl as repl_mode +from mycli.packages.sqlresult import SQLResult + + +class DummyLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def debug(self, *args: Any, **kwargs: Any) -> None: + self.debug_calls.append((args, kwargs)) + + def error(self, *args: Any, **kwargs: Any) -> None: + self.error_calls.append((args, kwargs)) + + +class HashableNamespace: + pass + + +@dataclass +class DummyFormatterWithQuery: + query: str = '' + + +class FakeApp: + def __init__(self, text: str = '', render_counter: int = 0) -> None: + self.current_buffer = SimpleNamespace(text=text) + self.render_counter = render_counter + self.ttimeoutlen: float | None = None + + +class FakePromptOutput: + def __init__(self, columns: int = 80, rows: int = 24) -> None: + self.columns = columns + self.rows = rows + self.bell_count = 0 + + def get_size(self) -> SimpleNamespace: + return SimpleNamespace(columns=self.columns, rows=self.rows) + + def bell(self) -> None: + self.bell_count += 1 + + +class FakePromptSession: + def __init__(self, responses: list[Any] | None = None, columns: int = 80, rows: int = 24) -> None: + self.responses = list(responses or []) + self.output = FakePromptOutput(columns=columns, rows=rows) + self.app = FakeApp() + self.prompt_calls: list[dict[str, Any]] = [] + + def prompt(self, **kwargs: Any) -> str: + self.prompt_calls.append(dict(kwargs)) + if not self.responses: + raise EOFError() + response = self.responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class FakeCursorBase: + def __init__( + self, + rows: list[tuple[Any, ...]] | None = None, + rowcount: int = 0, + warning_count: int = 0, + ) -> None: + self._rows = list(rows or []) + self.rowcount = rowcount + self.warning_count = warning_count + + def __iter__(self) -> Iterator[tuple[Any, ...]]: + return iter(self._rows) + + +class FakeConnection: + def __init__(self, ping_exc: Exception | None = None, cursor_value: Any = 'cursor') -> None: + self.ping_exc = ping_exc + self.cursor_value = cursor_value + self.ping_calls: list[bool] = [] + + def ping(self, reconnect: bool = False) -> None: + self.ping_calls.append(reconnect) + if self.ping_exc is not None: + raise self.ping_exc + + def cursor(self) -> Any: + return self.cursor_value + + +class ReusableLock: + def __enter__(self) -> 'ReusableLock': + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + +def sqlresult_generator(*results: SQLResult) -> Generator[SQLResult, None, None]: + for result in results: + yield result + + +class FakeResourceTree: + def __init__(self, files: dict[str, str], path: str | None = None) -> None: + self.files = files + self.path = path + + def joinpath(self, path: str) -> 'FakeResourceTree': + return FakeResourceTree(self.files, path) + + def open(self, mode: str) -> StringIO: + assert self.path is not None + if self.path not in self.files: + raise FileNotFoundError(self.path) + return StringIO(self.files[self.path]) + + +def make_repl_cli(sqlexecute: Any | None = None) -> Any: + cli: Any = HashableNamespace() + cli.logger = DummyLogger() + cli.query_history = [] + cli.last_prompt_message = repl_mode.ANSI('') + cli.last_custom_toolbar_message = repl_mode.ANSI('') + cli.prompt_lines = 0 + cli.default_prompt = r'\t \u@\h:\d> ' + cli.default_prompt_splitln = r'\u@\h\n(\t):\d>' + cli.max_len_prompt = 45 + cli.prompt_format = cli.default_prompt + cli.multiline_continuation_char = '>' + cli.toolbar_format = 'default' + cli.verbosity = -1 + cli.keepalive_ticks = None + cli._keepalive_counter = 0 + cli.auto_vertical_output = False + cli.beep_after_seconds = 0.0 + cli.show_warnings = False + cli.null_string = '' + cli.numeric_alignment = 'right' + cli.binary_display = None + cli.prompt_session = None + cli.post_redirect_command = None + cli.logfile = None + cli.smart_completion = False + cli.config = {'main': {'history_file': '~/.mycli-history-testing'}} + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + cli.login_path = None + cli.login_path_as_host = False + cli.dsn_alias = None + cli.terminal_tab_title_format = '' + cli.terminal_window_title_format = '' + cli.multiplex_window_title_format = '' + cli.multiplex_pane_title_format = '' + cli._completer_lock = ReusableLock() + cli.completer = object() + cli.syntax_style = 'native' + cli.cli_style = {} + cli.emacs_ttimeoutlen = 1.0 + cli.vi_ttimeoutlen = 2.0 + cli.sandbox_mode = False + cli.destructive_warning = False + cli.destructive_keywords = ['drop'] + cli.llm_prompt_field_truncate = 0 + cli.llm_prompt_section_truncate = 0 + cli.main_formatter = DummyFormatterWithQuery() + cli.redirect_formatter = DummyFormatterWithQuery() + cli.pager_configured = 0 + refresh_calls: list[bool] = [] + output_calls: list[tuple[list[str], Any, bool]] = [] + echo_calls: list[str] = [] + timing_calls: list[tuple[str, bool]] = [] + log_queries: list[str] = [] + cli.refresh_calls = refresh_calls + cli.output_calls = output_calls + cli.echo_calls = echo_calls + cli.timing_calls = timing_calls + cli.log_queries = log_queries + cli.title_calls = 0 + cli.sqlexecute = sqlexecute + cli.get_reserved_space = lambda: 3 + cli.get_last_query = lambda: cli.query_history[-1].query if cli.query_history else None + cli.configure_pager = lambda: setattr(cli, 'pager_configured', cli.pager_configured + 1) + + def refresh_completions(reset: bool = False) -> list[SQLResult]: + cli.refresh_calls.append(reset) + return [SQLResult(status='refresh')] + + cli.refresh_completions = refresh_completions + + def output_timing(timing: str, is_warnings_style: bool = False) -> None: + cli.timing_calls.append((timing, is_warnings_style)) + + cli.output_timing = output_timing + + def log_query(text: str) -> None: + cli.log_queries.append(text) + + cli.log_query = log_query + cli.reconnect = lambda database='': False + + def echo(message: Any, **kwargs: Any) -> None: + cli.echo_calls.append(str(message)) + + cli.echo = echo + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + return iter([str(kwargs.get('max_width')), result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + + def output(formatted: Any, result: Any, is_warnings_style: bool = False) -> None: + cli.output_calls.append((list(formatted), result, is_warnings_style)) + + cli.output = output + return cli + + +def patch_repl_runtime_defaults(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(repl_mode.special, 'set_expanded_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'set_forced_horizontal_output', lambda value: None) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: False) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode.special, 'write_tee', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'unset_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'flush_pipe_once_if_written', lambda *args, **kwargs: None) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: False) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: False) + monkeypatch.setattr(repl_mode, 'confirm_destructive_query', lambda keywords, text: None) + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: False) + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: False) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + +def test_complete_while_typing_filter_covers_threshold_and_word_rules(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(repl_mode, 'MIN_COMPLETION_TRIGGER', 3) + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='ab'))) + assert repl_mode.complete_while_typing_filter() is False + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='abc'))) + assert repl_mode.complete_while_typing_filter() is True + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='source xyz'))) + assert repl_mode.complete_while_typing_filter() is True + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='source x/'))) + assert repl_mode.complete_while_typing_filter() is False + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='\\. abc'))) + assert repl_mode.complete_while_typing_filter() is True + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='\\. a/'))) + assert repl_mode.complete_while_typing_filter() is False + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='select abc'))) + assert repl_mode.complete_while_typing_filter() is True + + monkeypatch.setattr(repl_mode, 'get_app', lambda: SimpleNamespace(current_buffer=SimpleNamespace(text='select a!'))) + assert repl_mode.complete_while_typing_filter() is False + + monkeypatch.setattr(repl_mode, 'MIN_COMPLETION_TRIGGER', 1) + assert repl_mode.complete_while_typing_filter() is True + + +def test_repl_create_history(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli() + monkeypatch.setenv('MYCLI_HISTFILE', '~/override-history') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: True) + monkeypatch.setattr(repl_mode, 'FileHistoryWithTimestamp', lambda path: f'history:{path}') + history = cast(Any, repl_mode._create_history(cli)) + assert history == f'history:{os.path.expanduser("~/override-history")}' + + monkeypatch.delenv('MYCLI_HISTFILE') + monkeypatch.setattr(repl_mode, 'dir_path_exists', lambda path: False) + assert repl_mode._create_history(cli) is None + assert 'Unable to open the history file' in cli.echo_calls[-1] + + +def test_repl_picker_helpers_cover_present_and_missing_resources(monkeypatch: pytest.MonkeyPatch) -> None: + files = { + 'AUTHORS': '* Alice\n* Bob\n', + 'SPONSORS': '* Carol\n', + 'TIPS': '# comment\nTip 1\n\nTip 2\n', + } + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree(files)) + monkeypatch.setattr(repl_mode.random, 'choice', lambda seq: seq[0]) + assert repl_mode._contributors_picker() == 'Alice' + assert repl_mode._sponsors_picker() == 'Carol' + assert repl_mode._tips_picker() == 'Tip 1' + + monkeypatch.setattr(repl_mode.resources, 'files', lambda package: FakeResourceTree({})) + assert repl_mode._contributors_picker() == 'our contributors' + assert repl_mode._sponsors_picker() == 'our sponsors' + assert repl_mode._tips_picker() == r'\? or "help" for help!' + + +def test_repl_show_startup_banner_and_prompt_helpers(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace(server_info='Server')) + printed: list[str] = [] + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.4) + monkeypatch.setattr(repl_mode, '_contributors_picker', lambda: 'Alice') + monkeypatch.setattr(repl_mode, '_sponsors_picker', lambda: 'Carol') + monkeypatch.setattr(repl_mode, '_tips_picker', lambda: 'Tip') + + cli.verbosity = 0 + repl_mode._show_startup_banner(cli, cli.sqlexecute) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.6) + repl_mode._show_startup_banner(cli, cli.sqlexecute) + cli.verbosity = -1 + repl_mode._show_startup_banner(cli, cli.sqlexecute) + assert any('Thanks to the contributor' in line for line in printed) + assert any('Tip — Tip' in line for line in printed) + + monkeypatch.setattr( + repl_mode, + 'render_prompt_string', + lambda mycli, string, render_counter: to_formatted_text('0123456') if string == cli.default_prompt else 'a\nb', + ) + cli.max_len_prompt = 5 + prompt_text = to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=2)))) + assert prompt_text == 'a\nb' + assert cli.prompt_lines == 2 + + cli.last_prompt_message = repl_mode.ANSI('cached') + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='typing', render_counter=3)))) == 'cached' + + cli.prompt_format = 'custom' + cli.prompt_lines = 0 + monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: to_formatted_text('single')) + assert to_plain_text(repl_mode._get_prompt_message(cli, cast(Any, FakeApp(text='', render_counter=4)))) == 'single' + assert cli.prompt_lines == 1 + + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' > ')] + cli.multiline_continuation_char = '' + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', '')] + cli.multiline_continuation_char = None + assert repl_mode._get_continuation(cli, 4, 0, 0) == [('class:continuation', ' ')] + + +def test_repl_show_startup_banner_thanks_sponsor(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace(server_info='Server')) + cli.verbosity = 0 + printed: list[str] = [] + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) + monkeypatch.setattr(repl_mode.random, 'random', lambda: 0.25) + monkeypatch.setattr(repl_mode, '_sponsors_picker', lambda: 'Carol') + + repl_mode._show_startup_banner(cli, cli.sqlexecute) + + assert any('Thanks to the sponsor' in line and 'Carol' in line for line in printed) + + +def test_prompt_toolbar_and_title_helpers(monkeypatch: pytest.MonkeyPatch) -> None: + class PromptCursor: + def __enter__(self) -> 'PromptCursor': + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + class PromptConnection: + def cursor(self) -> PromptCursor: + return PromptCursor() + + sqlexecute = SimpleNamespace( + user='alice', + host='127.0.0.1', + dbname='db', + port=3307, + socket='/tmp/mysql.sock', + server_info=SimpleNamespace(species=SimpleNamespace(name='TiDB')), + conn=None, + ) + cli = make_repl_cli(sqlexecute) + cli.login_path = 'prod' + cli.login_path_as_host = True + cli.dsn_alias = 'dsn' + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\A|\y|\Y|\T|\w|\W', 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == 'prod|prod|dsn|(none)|(none)|(none)|(none)|' + + sqlexecute.conn = PromptConnection() + cli.login_path_as_host = False + monkeypatch.setattr(repl_mode, 'get_uptime', lambda cur: 123) + monkeypatch.setattr(repl_mode, 'format_uptime', lambda uptime: f'uptime:{uptime}') + monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3') + monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7) + prompt = repl_mode.render_prompt_string(cli, r'\H|\y|\Y|\T|\w|\W', 1) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == '127.0.0.1|123|uptime:123|TLSv1.3|7|7' + + cli.prompt_session = None + assert to_plain_text(repl_mode.get_custom_toolbar(cli, 'fmt')) == '' + cli.prompt_session = cast(Any, SimpleNamespace(app=None)) + assert to_plain_text(repl_mode.get_custom_toolbar(cli, 'fmt')) == '' + + cli.prompt_session = cast(Any, FakePromptSession()) + cli.last_custom_toolbar_message = repl_mode.ANSI('cached') + cli.prompt_session.app.current_buffer.text = 'typing' + assert repl_mode.get_custom_toolbar(cli, 'fmt') == cli.last_custom_toolbar_message + + cli.prompt_session.app.current_buffer.text = '' + monkeypatch.setattr(repl_mode, 'render_prompt_string', lambda mycli, string, render_counter: f'title:{string}') + assert 'title:fmt' in str(repl_mode.get_custom_toolbar(cli, 'fmt')) + + cli.terminal_tab_title_format = 'tab' + cli.terminal_window_title_format = 'window' + cli.multiplex_window_title_format = 'mux-window' + cli.multiplex_pane_title_format = 'mux-pane' + monkeypatch.setattr(repl_mode, 'sanitize_terminal_title', lambda title: title.upper()) + monkeypatch.setattr(repl_mode.sys.stderr, 'isatty', lambda: True) + printed: list[str] = [] + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(args[0])) + tmux_calls: list[tuple[Any, ...]] = [] + monkeypatch.setattr(repl_mode.subprocess, 'run', lambda *args, **kwargs: tmux_calls.append(args)) + monkeypatch.setenv('TMUX', '1') + repl_mode.set_all_external_titles(cli) + assert printed[0].startswith('\x1b]1;TITLE:TAB') + assert printed[1].startswith('\x1b]2;TITLE:WINDOW') + assert printed[2].startswith('\x1b]2;TITLE:MUX-PANE') + assert tmux_calls + + monkeypatch.setattr(repl_mode.sys.stderr, 'isatty', lambda: False) + repl_mode.set_external_terminal_tab_title(cli) + repl_mode.set_external_terminal_window_title(cli) + repl_mode.set_external_multiplex_pane_title(cli) + monkeypatch.delenv('TMUX', raising=False) + repl_mode.set_external_multiplex_window_title(cli) + monkeypatch.setenv('TMUX', '1') + monkeypatch.setattr(repl_mode.subprocess, 'run', lambda *args, **kwargs: (_ for _ in ()).throw(FileNotFoundError())) + repl_mode.set_external_multiplex_window_title(cli) + + +def test_prompt_and_title_helper_early_returns_and_remaining_prompt_branches(monkeypatch: pytest.MonkeyPatch) -> None: + class PromptCursor: + def __enter__(self) -> 'PromptCursor': + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + class PromptConnection: + def cursor(self) -> PromptCursor: + return PromptCursor() + + cli = make_repl_cli( + SimpleNamespace( + user='alice', + host=None, + dbname='db', + port=3306, + socket=None, + server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')), + conn=PromptConnection(), + ) + ) + cli.prompt_session = cast(Any, FakePromptSession()) + + monkeypatch.setattr(repl_mode, 'get_uptime', lambda cur: 123) + monkeypatch.setattr(repl_mode, 'format_uptime', lambda uptime: f'uptime:{uptime}') + monkeypatch.setattr(repl_mode, 'get_ssl_version', lambda cur: 'TLSv1.3') + monkeypatch.setattr(repl_mode, 'get_warning_count', lambda cur: 7) + + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\y|\Y', 0) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|123|uptime:123' + + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\w|\W', 1) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|7|7' + + prompt = repl_mode.render_prompt_string(cli, r'\h|\H|\T', 2) + prompt_plain = to_plain_text(prompt) + assert prompt_plain == f'{repl_mode.DEFAULT_HOST}|{repl_mode.DEFAULT_HOST}|TLSv1.3' + + monkeypatch.setattr(repl_mode.sys.stderr, 'isatty', lambda: True) + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('unexpected print'))) + monkeypatch.setattr( + repl_mode.subprocess, + 'run', + lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError('unexpected tmux call')), + ) + + cli.terminal_tab_title_format = '' + repl_mode.set_external_terminal_tab_title(cli) + cli.terminal_tab_title_format = 'tab' + cli.prompt_session = None + repl_mode.set_external_terminal_tab_title(cli) + + cli.prompt_session = cast(Any, FakePromptSession()) + cli.terminal_window_title_format = '' + repl_mode.set_external_terminal_window_title(cli) + cli.terminal_window_title_format = 'window' + cli.prompt_session = None + repl_mode.set_external_terminal_window_title(cli) + + cli.prompt_session = cast(Any, FakePromptSession()) + cli.multiplex_window_title_format = '' + repl_mode.set_external_multiplex_window_title(cli) + cli.multiplex_window_title_format = 'mux-window' + monkeypatch.setenv('TMUX', '1') + cli.prompt_session = None + repl_mode.set_external_multiplex_window_title(cli) + + cli.prompt_session = cast(Any, FakePromptSession()) + cli.multiplex_pane_title_format = '' + repl_mode.set_external_multiplex_pane_title(cli) + cli.multiplex_pane_title_format = 'mux-pane' + monkeypatch.delenv('TMUX', raising=False) + repl_mode.set_external_multiplex_pane_title(cli) + monkeypatch.setenv('TMUX', '1') + cli.prompt_session = None + repl_mode.set_external_multiplex_pane_title(cli) + + +def test_maybe_html_escape() -> None: + assert repl_mode.maybe_html_escape('plain', False) == 'plain' + assert repl_mode.maybe_html_escape('a&b<1>', True) == 'a&b<1>' + + +def test_render_prompt_string_html() -> None: + repl_mode.render_prompt_string.cache_clear() + + cli = make_repl_cli( + SimpleNamespace( + user='ab', + host='db.example.com', + dbname='nameprod', + port=3306, + socket=None, + server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')), + conn=None, + ) + ) + cli.dsn_alias = 'aliasone' + + html_prompt = repl_mode.render_prompt_string(cli, r'\\u@\d|\A\', 0) + assert to_plain_text(html_prompt) == 'ab@nameprod|aliasone' + + bad_html_prompt = repl_mode.render_prompt_string(cli, r'\\u', 1) + assert to_plain_text(bad_html_prompt) == '(cannot parse HTML prompt string)' + + ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2) + assert to_plain_text(ansi_prompt) == 'red' + + +def test_render_prompt_string_ansi() -> None: + repl_mode.render_prompt_string.cache_clear() + + cli = make_repl_cli( + SimpleNamespace( + user='ab', + host='db.example.com', + dbname='nameprod', + port=3306, + socket=None, + server_info=SimpleNamespace(species=SimpleNamespace(name='MySQL')), + conn=None, + ) + ) + cli.dsn_alias = 'aliasone' + + ansi_prompt = repl_mode.render_prompt_string(cli, r'\x1b[31mred\x1b[0m', 2) + assert to_plain_text(ansi_prompt) == 'red' + + +def test_output_results_covers_watch_warning_timing_beep_and_interrupts(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [SQLResult(status='warning', rows=[('warn',)])] + + cli = make_repl_cli(FakeSQLExecute()) + cli.auto_vertical_output = True + cli.prompt_session = FakePromptSession(columns=91) + cli.beep_after_seconds = 0.1 + state = repl_mode.ReplState() + format_widths: list[int | None] = [] + + def format_sqlresult(result: SQLResult, **kwargs: Any) -> Iterator[str]: + format_widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult + time_values = iter([0.2, 1.0, 2.0, 3.0, 3.2]) + monkeypatch.setattr(repl_mode.time, 'time', lambda: next(time_values)) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_show_warnings_enabled', lambda: True) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: status == 'mut') + + results = sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='mut', rows=cast(Any, FakeCursorBase(rowcount=1, warning_count=1))), + ) + + repl_mode._output_results(cli, state, results, start=0.0) + + assert state.mutating is True + assert format_widths[:2] == [91, 91] + assert cli.prompt_session.output.bell_count == 2 + assert '' in cli.echo_calls + assert any(is_warnings_style is True for _, _, is_warnings_style in cli.output_calls) + assert any(is_warnings_style is False for _, is_warnings_style in cli.timing_calls) + assert any(is_warnings_style is True for _, is_warnings_style in cli.timing_calls) + + cli_interrupt = make_repl_cli(SimpleNamespace()) + cli_interrupt.echo = lambda message, **kwargs: ( + (_ for _ in ()).throw(KeyboardInterrupt()) if message == '' else cli_interrupt.echo_calls.append(str(message)) + ) + cli_interrupt.output = lambda formatted, result, is_warnings_style=False: (_ for _ in ()).throw(KeyboardInterrupt()) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + repl_mode._output_results( + cli_interrupt, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='first'), SQLResult(status='second')), + start=0.0, + ) + + +def test_output_results_handles_abort_default_width_and_bad_watch(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.auto_vertical_output = True + widths: list[int | None] = [] + + def format_sqlresult_with_width(result: SQLResult, **kwargs: Any) -> Iterator[str]: + widths.append(kwargs.get('max_width')) + return iter([result.status_plain or 'row']) + + cli.format_sqlresult = format_sqlresult_with_width + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode, 'confirm', lambda text: False) + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001)))), + start=0.0, + ) + assert 'The result set has more than 1000 rows.' in cli.echo_calls + assert 'Aborted!' in cli.echo_calls + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator(SQLResult(status='ok')), + start=0.0, + ) + assert widths[-1] == repl_mode.DEFAULT_WIDTH + + monkeypatch.setattr(repl_mode, 'is_select', lambda status: False) + with pytest.raises(SystemExit): + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': 'bad'}), + ), + start=0.0, + ) + + +def test_keepalive_hook_covers_threshold_and_errors() -> None: + cli = make_repl_cli(SimpleNamespace(conn=FakeConnection())) + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 0 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + + cli.keepalive_ticks = 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 1 + repl_mode._keepalive_hook(cli, None) + assert cli._keepalive_counter == 0 + assert cli.sqlexecute.conn.ping_calls == [False] + + cli.sqlexecute.conn = FakeConnection(ping_exc=RuntimeError('boom')) + repl_mode._keepalive_hook(cli, None) + repl_mode._keepalive_hook(cli, None) + assert any('keepalive ping error' in call[0][0] for call in cli.logger.debug_calls) + + +def test_build_prompt_session_covers_toolbar_modes_and_editing_modes(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: list[dict[str, Any]] = [] + toolbar_help: list[bool] = [] + + def fake_prompt_session(**kwargs: Any) -> FakePromptSession: + captured_kwargs.append(kwargs) + return FakePromptSession() + + monkeypatch.setattr(repl_mode, 'PromptSession', fake_prompt_session) + monkeypatch.setattr(repl_mode, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') + monkeypatch.setattr(repl_mode, 'cli_is_multiline', lambda mycli: False) + + def fake_toolbar_tokens(mycli: Any, show_help: Any, fmt: str, custom_toolbar: Any) -> str: + toolbar_help.append(show_help()) + assert callable(custom_toolbar) + return 'toolbar' + + monkeypatch.setattr(repl_mode, 'create_toolbar_tokens_func', fake_toolbar_tokens) + + cli = make_repl_cli(SimpleNamespace()) + state = repl_mode.ReplState() + cli.toolbar_format = 'none' + cli.key_bindings = 'vi' + cli.wider_completion_menu = True + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + first_kwargs = captured_kwargs[-1] + assert first_kwargs['bottom_toolbar'] is None + assert first_kwargs['complete_style'] == repl_mode.CompleteStyle.MULTI_COLUMN + assert first_kwargs['editing_mode'] == repl_mode.EditingMode.VI + assert cli.prompt_session.app.ttimeoutlen == cli.vi_ttimeoutlen + + cli.toolbar_format = 'default' + cli.key_bindings = 'emacs' + cli.wider_completion_menu = False + state.iterations = 0 + repl_mode._build_prompt_session(cli, state, history=cast(Any, 'history'), key_bindings=cast(Any, 'bindings')) + latest_kwargs = captured_kwargs[-1] + assert latest_kwargs['bottom_toolbar'] == 'toolbar' + assert latest_kwargs['complete_style'] == repl_mode.CompleteStyle.COLUMN + assert latest_kwargs['editing_mode'] == repl_mode.EditingMode.EMACS + assert toolbar_help == [True] + assert cli.prompt_session.app.ttimeoutlen == cli.emacs_ttimeoutlen + assert latest_kwargs['prompt_continuation'](4, 0, 0) == [('class:continuation', ' > ')] + + +def test_one_iteration_handles_prompt_interrupt_empty_editor_clip_and_clip_true(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + cli = make_repl_cli(SimpleNamespace(run=lambda text: iter([SQLResult(status='ok')]), conn=FakeConnection())) + cli.keepalive_ticks = 1 + cli.prompt_session = FakePromptSession([KeyboardInterrupt(), ' ', 'edit-error', 'clip-error', 'clip-stop']) + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + inputhook = cli.prompt_session.prompt_calls[-1]['inputhook'] + assert inputhook is not None + inputhook(None) + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda *args: (_ for _ in ()).throw(RuntimeError('edit boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'edit boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_editor_command', lambda mycli, text, inputhook, loaded_message_fn: text) + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: (_ for _ in ()).throw(RuntimeError('clip boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert 'clip boom' in cli.echo_calls[-1] + + monkeypatch.setattr(repl_mode, 'handle_clip_command', lambda mycli, text: True) + repl_mode._one_iteration(cli, repl_mode.ReplState()) + assert cli.query_history == [] + + +def test_one_iteration_covers_llm_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + click_output: list[str] = [] + monkeypatch.setattr(repl_mode.click, 'echo', lambda message='', **kwargs: click_output.append(str(message))) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode.special, 'is_llm_command', lambda text: text.startswith('\\llm')) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.conn = FakeConnection(cursor_value='cursor') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('context', 'select 1', 1.25), + ) + cli = make_repl_cli(FakeSQLExecute()) + cli.prompt_session = FakePromptSession(['\\llm ask', 'select 1']) + repl_mode._one_iteration( + cli, + repl_mode.ReplState(), + ) + assert click_output[:3] == ['LLM Response:', 'context', '---'] + assert cli.output_calls[0][0] == ['None', 'ran:select 1'] + + cli_finish = make_repl_cli(FakeSQLExecute()) + cli_finish.prompt_session = FakePromptSession(['\\llm finish']) + cli_finish.format_sqlresult = lambda result, **kwargs: iter([result.status_plain or 'row']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(iter([SQLResult(status='done')]))), + ) + repl_mode._one_iteration(cli_finish, repl_mode.ReplState()) + assert cli_finish.output_calls[0][0] == ['done'] + + cli_empty = make_repl_cli(FakeSQLExecute()) + cli_empty.prompt_session = FakePromptSession(['\\llm empty']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(repl_mode.special.FinishIteration(None)), + ) + repl_mode._one_iteration(cli_empty, repl_mode.ReplState()) + assert cli_empty.output_calls == [] + + cli_err = make_repl_cli(FakeSQLExecute()) + cli_err.prompt_session = FakePromptSession(['\\llm err']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError('llm boom')), + ) + repl_mode._one_iteration(cli_err, repl_mode.ReplState()) + assert 'llm boom' in cli_err.echo_calls[-1] + + cli_interrupt = make_repl_cli(FakeSQLExecute()) + cli_interrupt.prompt_session = FakePromptSession(['\\llm stop']) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda *args, **kwargs: (_ for _ in ()).throw(KeyboardInterrupt()), + ) + repl_mode._one_iteration(cli_interrupt, repl_mode.ReplState()) + assert cli_interrupt.output_calls == [] + + cli_quiet = make_repl_cli(FakeSQLExecute()) + cli_quiet.prompt_session = FakePromptSession(['\\llm quiet', 'select 2']) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: False) + monkeypatch.setattr( + repl_mode.special, + 'handle_llm', + lambda text, cur, dbname, field_truncate, section_truncate: ('', 'select 2', 0.5), + ) + repl_mode._one_iteration(cli_quiet, repl_mode.ReplState()) + assert cli_quiet.output_calls[0][0] == ['None', 'ran:select 2'] + + +@pytest.mark.parametrize( + 'text, expected', + [ + ('', True), + (' ', True), + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ('alter user root identified by "pw"', True), + ("SET PASSWORD = 'newpass'", True), + ("set password = 'newpass'", True), + ('quit', True), + ('exit', True), + ('\\q', True), + ('SELECT 1', False), + ('DROP TABLE t', False), + ('USE mydb', False), + ('SHOW DATABASES', False), + ], +) +def test_is_sandbox_allowed(text: str, expected: bool) -> None: + from mycli.packages.sql_utils import is_sandbox_allowed + + assert is_sandbox_allowed(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ("SET PASSWORD = 'newpass'", True), + ('SELECT 1', False), + ('quit', False), + ], +) +def test_is_password_change(text: str, expected: bool) -> None: + from mycli.packages.sql_utils import is_password_change + + assert is_password_change(text) is expected + + +@pytest.mark.parametrize( + 'text, expected', + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'", 'newpass'), + ("SET PASSWORD = 'secret123'", 'secret123'), + ("ALTER USER root IDENTIFIED BY 'p@ss w0rd!'", 'p@ss w0rd!'), + ('ALTER USER root IDENTIFIED WITH mysql_native_password', None), + ('SELECT 1', None), + ], +) +def test_extract_new_password(text: str, expected: str | None) -> None: + from mycli.packages.sql_utils import extract_new_password + + assert extract_new_password(text) == expected + + +def test_one_iteration_blocks_disallowed_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status=f'ran:{text}')]) + + cli = make_repl_cli(FakeSQLExecute()) + cli.sandbox_mode = True + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'SELECT 1') + assert any('ERROR 1820' in msg for msg in cli.echo_calls) + assert not cli.query_history + + +def test_one_iteration_allows_alter_user_in_sandbox_mode(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + self.connect_calls: list[bool] = [] + + def connect(self) -> None: + self.connect_calls.append(True) + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") + assert cli.sandbox_mode is False + assert sqlexecute.password == 'newpass' + assert sqlexecute.connect_calls == [True] + assert any('Reconnected' in msg for msg in cli.echo_calls) + + +def test_one_iteration_sandbox_reconnect_failure(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + self.password = 'old' + + def connect(self) -> None: + raise RuntimeError('connection refused') + + def run(self, text: str) -> Iterator[SQLResult]: + return iter([SQLResult(status='OK')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.sandbox_mode = True + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), "ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'") + assert cli.sandbox_mode is False + assert any('reconnection failed' in msg for msg in cli.echo_calls) + + +def test_one_iteration_enters_sandbox_mode_on_must_change_password_error(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + dbname = 'db' + connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + raise pymysql.OperationalError(repl_mode.ER_MUST_CHANGE_PASSWORD, 'must change password') + + cli = make_repl_cli(FakeSQLExecute()) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'SELECT 1') + + assert cli.sandbox_mode is True + assert any('ERROR 1820' in msg for msg in cli.echo_calls) + assert cli.query_history[-1].query == 'SELECT 1' + assert cli.query_history[-1].successful is False + + +def test_one_iteration_covers_redirect_destructive_success_refresh_and_logfile(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def connect(self) -> None: + self.calls.append('connect') + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + return iter([SQLResult(status='DROP 1')]) + + sqlexecute = FakeSQLExecute() + cli = make_repl_cli(sqlexecute) + cli.logfile = False + cli.destructive_warning = True + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('dropdb', 'tee', '>', 'out.txt')) + redirects: list[tuple[Any, ...]] = [] + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: redirects.append(args)) + monkeypatch.setattr( + repl_mode, + 'confirm_destructive_query', + lambda keywords, text: None if text == 'dropdb' else (True if text == 'approved' else False), + ) + monkeypatch.setattr(repl_mode, 'is_dropping_database', lambda text, dbname: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_refresh', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'need_completion_reset', lambda text: text == 'dropdb') + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: True) + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect') + assert redirects == [('tee', '>', 'out.txt')] + assert cli.refresh_calls == [True] + assert cli.query_history[-1].query == 'dropdb' + assert cli.query_history[-1].successful is True + assert cli.query_history[-1].mutating is True + assert sqlexecute.dbname is None + assert sqlexecute.calls == ['dropdb', 'connect'] + assert 'Warning: This query was not logged.' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'approved') + assert 'Your call!' in cli.echo_calls + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'denied') + assert 'Wise choice!' in cli.echo_calls + + +def test_one_iteration_covers_reconnect_and_error_paths(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class InterfaceSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'iface' and self.calls.count('iface') == 1: + raise pymysql.err.InterfaceError() + return iter([SQLResult(status=f'ok:{text}')]) + + interface_sql = InterfaceSQLExecute() + cli_interface = make_repl_cli(interface_sql) + interface_reconnect_calls: list[str] = [] + interface_results = iter([True]) + + def interface_reconnect(database: str = '') -> bool: + interface_reconnect_calls.append(database) + return next(interface_results) + + cli_interface.reconnect = interface_reconnect + + repl_mode._one_iteration(cli_interface, repl_mode.ReplState(), 'iface') + assert interface_sql.calls.count('iface') == 2 + assert cli_interface.query_history[-1].query == 'iface' + assert interface_reconnect_calls == [''] + + cli_interface_false = make_repl_cli(InterfaceSQLExecute()) + false_calls: list[str] = [] + + def interface_reconnect_false(database: str = '') -> bool: + false_calls.append(database) + return False + + cli_interface_false.reconnect = interface_reconnect_false + repl_mode._one_iteration(cli_interface_false, repl_mode.ReplState(), 'iface') + assert false_calls == [''] + + class ErrorSQLExecute: + def __init__(self) -> None: + self.dbname: str | None = 'db' + self.connection_id = 0 + self.calls: list[str] = [] + + def run(self, text: str) -> Iterator[SQLResult]: + self.calls.append(text) + if text == 'oplost' and self.calls.count('oplost') == 1: + raise pymysql.OperationalError(2003, 'lost') + if text == 'opbad': + raise pymysql.OperationalError(9999, 'bad op') + if text == 'nyi': + raise NotImplementedError() + if text == 'boom': + raise RuntimeError('boom') + if text == 'eof': + raise EOFError() + return iter([SQLResult(status=f'ok:{text}')]) + + error_sql = ErrorSQLExecute() + cli_error = make_repl_cli(error_sql) + error_reconnect_calls: list[str] = [] + + def error_reconnect(database: str = '') -> bool: + error_reconnect_calls.append(database) + return True + + cli_error.reconnect = error_reconnect + + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'oplost') + assert error_sql.calls.count('oplost') == 2 + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'opbad') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'nyi') + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'boom') + assert any('bad op' in line for line in cli_error.echo_calls) + assert 'Not Yet Implemented.' in cli_error.echo_calls + assert any('boom' in line for line in cli_error.echo_calls) + assert error_reconnect_calls == [''] + + cli_error_false = make_repl_cli(ErrorSQLExecute()) + false_reconnect_calls: list[str] = [] + + def error_reconnect_false(database: str = '') -> bool: + false_reconnect_calls.append(database) + return False + + cli_error_false.reconnect = error_reconnect_false + repl_mode._one_iteration(cli_error_false, repl_mode.ReplState(), 'oplost') + assert false_reconnect_calls == [''] + + with pytest.raises(EOFError): + repl_mode._one_iteration(cli_error, repl_mode.ReplState(), 'eof') + + +def test_one_iteration_reraises_eoferror(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class EofSQLExecute: + dbname = 'db' + connection_id = 0 + + def run(self, text: str) -> Iterator[SQLResult]: + raise EOFError() + + with pytest.raises(EOFError): + repl_mode._one_iteration(make_repl_cli(EofSQLExecute()), repl_mode.ReplState(), 'eof') + + +def test_one_iteration_covers_cancel_paths_and_redirect_error(monkeypatch: pytest.MonkeyPatch) -> None: + patch_repl_runtime_defaults(monkeypatch) + + class FakeSQLExecute: + def __init__(self) -> None: + self.dbname = 'db' + self.connection_id = 0 + + def connect(self) -> None: + return None + + def run(self, text: str) -> Iterator[SQLResult]: + if text == 'cancel-ok': + self.connection_id = 7 + raise KeyboardInterrupt() + if text == 'kill 7': + return iter([SQLResult(status='OK')]) + if text == 'cancel-fail': + self.connection_id = 8 + raise KeyboardInterrupt() + if text == 'kill 8': + return iter([SQLResult(status='failed')]) + if text == 'cancel-error': + self.connection_id = 9 + raise KeyboardInterrupt() + if text == 'kill 9': + raise RuntimeError('kill failed') + if text == 'cancel-missing': + self.connection_id = 0 + raise KeyboardInterrupt() + return iter([SQLResult(status='ok')]) + + cli = make_repl_cli(FakeSQLExecute()) + monkeypatch.setattr(repl_mode, 'is_redirect_command', lambda text: text == 'redirect-bad') + monkeypatch.setattr(repl_mode, 'get_redirect_components', lambda text: ('sql', 'tee', '>', 'out.txt')) + monkeypatch.setattr(repl_mode.special, 'set_redirect', lambda *args: (_ for _ in ()).throw(RuntimeError('redirect boom'))) + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'redirect-bad') + assert 'redirect boom' in cli.echo_calls[-1] + + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-ok') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-fail') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-error') + repl_mode._one_iteration(cli, repl_mode.ReplState(), 'cancel-missing') + assert 'Cancelled query id: 7' in cli.echo_calls + assert any('Failed to confirm query cancellation' in line for line in cli.echo_calls) + assert any('Encountered error while cancelling query' in line for line in cli.echo_calls) + assert 'Did not get a connection id, skip cancelling query' in cli.echo_calls + + +def test_main_repl_covers_setup_loop_and_goodbye(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.verbosity = 0 + cli.smart_completion = True + loop_iterations: list[int] = [] + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_session', FakePromptSession()), + ) + + def fake_one_iteration(mycli: Any, state: repl_mode.ReplState) -> None: + loop_iterations.append(state.iterations) + if len(loop_iterations) == 2: + raise EOFError() + + closed: list[bool] = [] + monkeypatch.setattr(repl_mode, '_one_iteration', fake_one_iteration) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: closed.append(True)) + monkeypatch.setattr(repl_mode, 'set_all_external_titles', lambda mycli: setattr(mycli, 'title_calls', mycli.title_calls + 1)) + + repl_mode.main_repl(cli) + + assert cli.pager_configured == 1 + assert cli.refresh_calls == [False] + assert cli.title_calls == 1 + assert loop_iterations == [0, 1] + assert closed == [True] + assert cli.echo_calls[-1] == 'Goodbye!' + + +def test_main_repl_covers_no_refresh_and_quiet_exit(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_repl_cli(SimpleNamespace()) + cli.verbosity = -1 + cli.smart_completion = False + monkeypatch.setattr(repl_mode, '_create_history', lambda mycli: 'history') + monkeypatch.setattr(repl_mode, 'mycli_bindings', lambda mycli: 'bindings') + monkeypatch.setattr(repl_mode, '_show_startup_banner', lambda mycli, sqlexecute: None) + monkeypatch.setattr( + repl_mode, + '_build_prompt_session', + lambda mycli, state, history, key_bindings: setattr(mycli, 'prompt_session', FakePromptSession()), + ) + monkeypatch.setattr(repl_mode, '_one_iteration', lambda mycli, state: (_ for _ in ()).throw(EOFError())) + monkeypatch.setattr(repl_mode.special, 'close_tee', lambda: None) + monkeypatch.setattr(repl_mode, 'set_all_external_titles', lambda mycli: setattr(mycli, 'title_calls', mycli.title_calls + 1)) + + repl_mode.main_repl(cli) + + assert cli.refresh_calls == [] + assert cli.echo_calls == [] + + +def test_output_results_covers_remaining_watch_select_and_warning_branches(monkeypatch: pytest.MonkeyPatch) -> None: + class WarninglessSQLExecute: + def run(self, text: str) -> list[SQLResult]: + assert text == 'SHOW WARNINGS' + return [] + + cli = make_repl_cli(WarninglessSQLExecute()) + cli.show_warnings = True + cli.auto_vertical_output = False + cli.prompt_session = FakePromptSession(columns=77) + monkeypatch.setattr(repl_mode, 'Cursor', FakeCursorBase) + monkeypatch.setattr(repl_mode, 'is_mutating', lambda status: False) + monkeypatch.setattr(repl_mode, 'confirm', lambda text: True) + monkeypatch.setattr(repl_mode.special, 'is_expanded_output', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_redirected', lambda: False) + monkeypatch.setattr(repl_mode.special, 'is_timing_enabled', lambda: True) + monkeypatch.setattr(repl_mode, 'is_select', lambda status: status == 'select') + monkeypatch.setattr(repl_mode.time, 'time', lambda: 0.0) + + repl_mode._output_results( + cli, + repl_mode.ReplState(), + sqlresult_generator( + SQLResult(status='watch', command={'name': 'watch', 'seconds': '1'}), + SQLResult(status='watch', command={'name': 'watch', 'seconds': '2'}), + SQLResult(status='select', rows=cast(Any, FakeCursorBase(rowcount=1001, warning_count=1))), + ), + start=0.0, + ) + assert cli.output_calls diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py new file mode 100644 index 00000000..1712115a --- /dev/null +++ b/test/pytests/test_main_regression.py @@ -0,0 +1,1502 @@ +""" +These generated regression tests against main.py may be more brittle than +the primary tests in test_main.py. + +In addition, the tests in this file may enforce contracts that need not be +kept if main.py is refactored. + +Therefore authors should be free about + + * migrating individual tests if content moves out of main.py + * migrating individual tests to test_main.py after assessment of quality + * removing and rewriting these tests if contracts change +""" + +from __future__ import annotations + +import builtins +from collections.abc import Generator +import importlib.util +from io import StringIO +import itertools +import os +from pathlib import Path +import shutil +import sys +from types import ModuleType, SimpleNamespace +from typing import Any, cast + +import click +from click.testing import CliRunner +from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ( + ANSI, + FormattedText, +) +import pymysql +import pytest + +from mycli import main +from mycli.cli_args import IntOrStringClickParamType +import mycli.key_bindings +import mycli.output as output_module +from mycli.packages.sqlresult import SQLResult +from test.utils import ( # type: ignore[attr-defined] + DummyFormatter, + DummyLogger, + FakeCursorBase, + RecordingSQLExecute, + call_click_entrypoint_direct, + make_bare_mycli, + make_dummy_mycli_class, +) + + +class FakeConnection: + def __init__(self, ping_exc: Exception | None = None) -> None: + self.ping_exc = ping_exc + self.ping_calls: list[bool] = [] + + def ping(self, reconnect: bool = False) -> None: + self.ping_calls.append(reconnect) + if self.ping_exc is not None: + raise self.ping_exc + + +class BoolSection(dict[str, Any]): + def as_bool(self, key: str) -> bool: + return str(self[key]).lower() == 'true' + + +class ToggleBool: + def __init__(self, values: list[bool]) -> None: + self.values = values + + def __bool__(self) -> bool: + if self.values: + return self.values.pop(0) + return False + + +class IntRaises: + def __bool__(self) -> bool: + return True + + def __int__(self) -> int: + raise ValueError('bad int') + + +def load_main_variant(monkeypatch: pytest.MonkeyPatch, *, fail_pwd: bool = False) -> ModuleType: + import builtins + + original_import = builtins.__import__ + + def fake_import(name: str, globals: Any = None, locals: Any = None, fromlist: Any = (), level: int = 0) -> Any: # noqa: A002 + if fail_pwd and name == 'pwd': + raise ImportError('no pwd') + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, '__import__', fake_import) + module_name = f'mycli_main_variant_{int(fail_pwd)}' + spec = importlib.util.spec_from_file_location(module_name, Path(main.__file__)) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def test_import_fallbacks_for_pwd(monkeypatch: pytest.MonkeyPatch) -> None: + module = load_main_variant(monkeypatch, fail_pwd=True) + + assert module.Query('sql', True, False).query == 'sql' + + +def test_register_special_commands_registers_expected_handlers(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + registered: list[tuple[Any, ...]] = [] + monkeypatch.setattr(main.special, 'register_special_command', lambda *args, **kwargs: registered.append(args)) + main.MyCli.register_special_commands(cli) + names = [args[1] for args in registered] + assert names == [ + 'use', + 'connect', + 'rehash', + 'tableformat', + 'redirectformat', + 'source', + 'prompt', + ] + + +def test_mycli_init_covers_config_warning_audit_log_and_login_path_errors(monkeypatch: pytest.MonkeyPatch) -> None: + class TypedSection(dict[str, Any]): + def as_bool(self, key: str) -> bool: + return str(self[key]).lower() == 'true' + + def as_float(self, key: str) -> float: + return float(self[key]) + + def as_int(self, key: str) -> int: + return int(self[key]) + + class TypedConfig(dict[str, Any]): + def __init__(self) -> None: + super().__init__({ + 'main': TypedSection({ + 'multi_line': 'false', + 'key_bindings': 'emacs', + 'timing': 'false', + 'show_favorite_query': 'false', + 'beep_after_seconds': '0', + 'table_format': 'ascii', + 'redirect_format': 'csv', + 'syntax_style': 'native', + 'less_chatty': 'true', + 'wider_completion_menu': 'false', + 'destructive_warning': 'false', + 'login_path_as_host': 'false', + 'post_redirect_command': '', + 'null_string': '', + 'numeric_alignment': 'right', + 'binary_display': '', + 'ssl_mode': 'bogus', + 'auto_vertical_output': 'false', + 'audit_log': '/tmp/audit.log', + 'smart_completion': 'false', + 'min_completion_trigger': '2', + 'prompt': '', + 'prompt_continuation': '>', + 'toolbar': 'default', + 'terminal_tab_title': '', + 'terminal_window_title': '', + 'multiplex_window_title': '', + 'multiplex_pane_title': '', + 'show_warnings': 'false', + }), + 'connection': TypedSection({'default_keepalive_ticks': '5', 'default_ssl_mode': None}), + 'keys': TypedSection({'emacs_ttimeoutlen': '1.0', 'vi_ttimeoutlen': '1.0'}), + 'colors': {}, + 'search': TypedSection({'highlight_preview': 'false'}), + 'llm': TypedSection({'prompt_field_truncate': '12', 'prompt_section_truncate': '34'}), + }) + self.filename = '/tmp/custom.rc' + + read_calls: list[tuple[bool, bool]] = [] + + def fake_read_config_files( + files: Any, ignore_package_defaults: bool = False, ignore_user_options: bool = False, **kwargs: Any + ) -> TypedConfig: + read_calls.append((ignore_package_defaults, ignore_user_options)) + return TypedConfig() + + write_default_calls: list[str] = [] + secho_calls: list[str] = [] + printed: list[str] = [] + monkeypatch.setattr(main, 'read_config_files', fake_read_config_files) + monkeypatch.setattr(main.special, 'set_timing_enabled', lambda enabled: None) + monkeypatch.setattr(main.special, 'set_show_favorite_query', lambda enabled: None) + monkeypatch.setattr(main, 'TabularOutputFormatter', lambda format_name: DummyFormatter(format_name)) + monkeypatch.setattr(main.sql_format, 'register_new_formatter', lambda formatter: None) + monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') + monkeypatch.setattr(main, 'style_factory_helpers', lambda *args, **kwargs: 'helpers') + monkeypatch.setattr(main.FavoriteQueries, 'from_config', classmethod(lambda cls, config: object())) + monkeypatch.setattr(main, 'CompletionRefresher', lambda: 'refresher') + monkeypatch.setattr(main, 'SQLCompleter', lambda *args, **kwargs: 'completer') + monkeypatch.setattr(main, 'write_default_config', lambda path: write_default_calls.append(path)) + monkeypatch.setattr(main, 'get_mylogin_cnf_path', lambda: '/tmp/mylogin.cnf') + monkeypatch.setattr(main, 'open_mylogin_cnf', lambda path: None) + monkeypatch.setattr(main.MyCli, 'register_special_commands', lambda self: None) + monkeypatch.setattr(main.MyCli, 'initialize_logging', lambda self: None) + monkeypatch.setattr(main.MyCli, 'read_my_cnf', lambda self, cnf, keys: {'prompt': None}) + monkeypatch.setattr(main.os.path, 'exists', lambda path: False) + monkeypatch.setattr(click, 'secho', lambda message, **kwargs: secho_calls.append(str(message))) + monkeypatch.setattr(builtins, 'print', lambda *args, **kwargs: printed.append(' '.join(str(x) for x in args))) + + def fake_open(path: Any, mode: str = 'r', *args: Any, **kwargs: Any) -> Any: + raise OSError('open failed') + + monkeypatch.setattr(builtins, 'open', fake_open) + mycli = main.MyCli(myclirc='/tmp/custom.rc') + assert mycli.llm_prompt_field_truncate == 12 + assert mycli.llm_prompt_section_truncate == 34 + assert mycli.ssl_mode is None + assert mycli.logfile is False + assert any('Invalid config option provided for ssl_mode' in msg for msg in secho_calls) + assert any('Unable to open the audit log file' in msg for msg in secho_calls) + assert printed == ['Error: Unable to read login path file.'] + assert write_default_calls == ['/tmp/custom.rc'] + assert read_calls == [(False, False), (True, False), (False, True), (False, False)] + + +def test_mycli_init_defaults_file_valid_ssl_and_mylogin_append(monkeypatch: pytest.MonkeyPatch) -> None: + class TypedSection(dict[str, Any]): + def as_bool(self, key: str) -> bool: + return str(self[key]).lower() == 'true' + + def as_float(self, key: str) -> float: + return float(self[key]) + + def as_int(self, key: str) -> int: + return int(self[key]) + + class TypedConfig(dict[str, Any]): + def __init__(self) -> None: + super().__init__({ + 'main': TypedSection({ + 'multi_line': 'false', + 'key_bindings': 'emacs', + 'timing': 'false', + 'show_favorite_query': 'false', + 'beep_after_seconds': '0', + 'table_format': 'ascii', + 'redirect_format': 'csv', + 'syntax_style': 'native', + 'less_chatty': 'true', + 'wider_completion_menu': 'false', + 'destructive_warning': 'false', + 'login_path_as_host': 'false', + 'post_redirect_command': '', + 'null_string': '', + 'numeric_alignment': 'right', + 'binary_display': '', + 'ssl_mode': 'auto', + 'auto_vertical_output': 'false', + 'smart_completion': 'false', + 'min_completion_trigger': '1', + 'prompt': '', + 'prompt_continuation': '>', + 'toolbar': 'default', + 'terminal_tab_title': '', + 'terminal_window_title': '', + 'multiplex_window_title': '', + 'multiplex_pane_title': '', + 'show_warnings': 'false', + }), + 'connection': TypedSection({'default_keepalive_ticks': '1', 'default_ssl_mode': None}), + 'keys': TypedSection({'emacs_ttimeoutlen': '1.0', 'vi_ttimeoutlen': '1.0'}), + 'colors': {}, + 'search': TypedSection({'highlight_preview': 'false'}), + }) + self.filename = '/tmp/custom.rc' + + mylogin_cnf = StringIO('[client]\nuser = alice\n') + monkeypatch.setattr(main, 'read_config_files', lambda *args, **kwargs: TypedConfig()) + monkeypatch.setattr(main.special, 'set_timing_enabled', lambda enabled: None) + monkeypatch.setattr(main.special, 'set_show_favorite_query', lambda enabled: None) + monkeypatch.setattr(main, 'TabularOutputFormatter', lambda format_name: DummyFormatter(format_name)) + monkeypatch.setattr(main.sql_format, 'register_new_formatter', lambda formatter: None) + monkeypatch.setattr(main, 'style_factory_ptoolkit', lambda *args, **kwargs: 'style') + monkeypatch.setattr(main, 'style_factory_helpers', lambda *args, **kwargs: 'helpers') + monkeypatch.setattr(main.FavoriteQueries, 'from_config', classmethod(lambda cls, config: object())) + monkeypatch.setattr(main, 'CompletionRefresher', lambda: 'refresher') + monkeypatch.setattr(main, 'SQLCompleter', lambda *args, **kwargs: 'completer') + monkeypatch.setattr(main.MyCli, 'register_special_commands', lambda self: None) + monkeypatch.setattr(main.MyCli, 'initialize_logging', lambda self: None) + monkeypatch.setattr(main.MyCli, 'read_my_cnf', lambda self, cnf, keys: {'prompt': None}) + monkeypatch.setattr(main, 'get_mylogin_cnf_path', lambda: '/tmp/mylogin.cnf') + monkeypatch.setattr(main, 'open_mylogin_cnf', lambda path: mylogin_cnf) + monkeypatch.setattr(main.os.path, 'exists', lambda path: True) + monkeypatch.setattr(click, 'secho', lambda *args, **kwargs: None) + + mycli = main.MyCli(defaults_file='/tmp/defaults.cnf', myclirc='/tmp/custom.rc') + assert mycli.cnf_files[0] == '/tmp/defaults.cnf' + assert mycli.cnf_files[-1] is mylogin_cnf + assert mycli.ssl_mode == 'auto' + assert mycli.llm_prompt_field_truncate == 0 + assert mycli.llm_prompt_section_truncate == 0 + + +def test_int_or_string_click_param_type_accepts_and_rejects_values() -> None: + param_type = IntOrStringClickParamType() + + assert param_type.convert(1, None, None) == 1 + assert param_type.convert('pw', None, None) == 'pw' + assert param_type.convert(None, None, None) is None + with pytest.raises(click.BadParameter): + param_type.convert(1.5, None, None) + + +def test_change_format_methods_cover_success_and_value_error() -> None: + cli = make_bare_mycli() + + result = next(main.MyCli.change_table_format(cli, 'ascii')) + assert result.status == 'Changed table format to ascii' + + cli.main_formatter = SimpleNamespace( + supported_formats=['ascii', 'csv'], + __setattr__=object.__setattr__, + ) + + class BadFormatter: + supported_formats = ['ascii', 'csv'] + + @property + def format_name(self) -> str: + return 'ascii' + + @format_name.setter + def format_name(self, value: str) -> None: + raise ValueError() + + cli.main_formatter = BadFormatter() + result = next(main.MyCli.change_table_format(cli, 'bad')) + assert 'Allowed formats' in str(result.status) + + cli.redirect_formatter = BadFormatter() + result = next(main.MyCli.change_redirect_format(cli, 'bad')) + assert 'Redirect format bad not recognized' in str(result.status) + + cli.redirect_formatter = DummyFormatter() + result = next(main.MyCli.change_redirect_format(cli, 'csv')) + assert result.status == 'Changed redirect format to csv' + + +def test_manual_reconnect() -> None: + cli = make_bare_mycli() + cli.reconnect = lambda database='': False # type: ignore[assignment] + assert next(main.MyCli.manual_reconnect(cli)).status == 'Not connected' + + cli.reconnect = lambda database='': True # type: ignore[assignment] + empty = next(main.MyCli.manual_reconnect(cli)) + assert empty.status is None + + def fake_change_db(arg: str) -> Generator[SQLResult, None, None]: + yield SQLResult(status=f'db:{arg}') + + cli.change_db = fake_change_db # type: ignore[assignment] + changed = next(main.MyCli.manual_reconnect(cli, 'prod')) + assert changed.status == 'db:prod' + + +def test_change_db_handles_empty_same_new_and_backticks(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + secho_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + monkeypatch.setattr(click, 'secho', lambda *args, **kwargs: secho_calls.append((args, kwargs))) + cli.sqlexecute = object.__new__(main.SQLExecute) + cli.sqlexecute.dbname = 'db1' + cli.sqlexecute.user = 'user1' + changed_to: list[str] = [] + cli.sqlexecute.change_db = lambda arg: changed_to.append(arg) # type: ignore[assignment] + titles_called = {'count': 0} + monkeypatch.setattr( + main, + 'set_all_external_titles', + lambda mycli: titles_called.__setitem__('count', titles_called['count'] + 1), + ) + + assert list(main.MyCli.change_db(cli, '')) == [] + assert secho_calls[0][0][0] == 'No database selected' + + same = next(main.MyCli.change_db(cli, 'db1')) + assert 'already connected' in str(same.status) + + cli.sqlexecute.dbname = 'db2' + new = next(main.MyCli.change_db(cli, '`db``name`')) + assert changed_to == ['db`name'] + assert 'now connected' in str(new.status) + assert titles_called['count'] == 2 + + +def test_execute_from_file(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + + class FakeSQLExecute: + def run(self, query: str) -> list[SQLResult]: + return [SQLResult(status=query)] + + monkeypatch.setattr(main, 'SQLExecute', FakeSQLExecute) + cli.sqlexecute = cast(Any, FakeSQLExecute()) + cli.destructive_warning = True + cli.destructive_keywords = ['drop'] + + assert list(main.MyCli.execute_from_file(cli, ''))[0].status == 'Missing required argument: filename.' + + missing = list(main.MyCli.execute_from_file(cli, str(tmp_path / 'missing.sql'))) + assert 'No such file' in str(missing[0].status) + + sql_file = tmp_path / 'query.sql' + sql_file.write_text('drop table test;', encoding='utf-8') + monkeypatch.setattr(main, 'confirm_destructive_query', lambda keywords, query: False) + stopped = list(main.MyCli.execute_from_file(cli, str(sql_file))) + assert stopped[0].status == 'Wise choice. Command execution stopped.' + + cli.destructive_warning = False + ran = list(main.MyCli.execute_from_file(cli, str(sql_file))) + assert ran[0].status == 'drop table test;' + + +def test_initialize_logging_covers_none_bad_path_and_file_handler(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(message) # type: ignore[assignment] + cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'NONE'}} + monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + main.MyCli.initialize_logging(cli) + + cli.config = {'main': {'log_file': str(tmp_path / 'missing' / 'mycli.log'), 'log_level': 'INFO'}} + monkeypatch.setattr(main, 'dir_path_exists', lambda path: False) + main.MyCli.initialize_logging(cli) + assert echo_calls[-1].startswith('Error: Unable to open the log file') + + cli.config = {'main': {'log_file': str(tmp_path / 'mycli.log'), 'log_level': 'INFO'}} + monkeypatch.setattr(main, 'dir_path_exists', lambda path: True) + main.MyCli.initialize_logging(cli) + + +def test_read_my_cnf_and_merge_ssl_with_cnf() -> None: + cli = make_bare_mycli() + cli.login_path = 'prod' + cli.defaults_suffix = '_suffix' + cnf = ConfigObj() + cnf['client'] = {'prompt': '"mysql>"', 'ssl-ca': '/tmp/ca.pem'} + cnf['mysqld'] = {'socket': "'/tmp/mysql.sock'", 'port': '3307'} + cnf['prod'] = {'user': '`alice`'} + cnf['client_suffix'] = {'prompt': "'alt>'"} + values = main.MyCli.read_my_cnf(cli, cnf, ['prompt', 'socket', 'port', 'user', 'ssl-ca']) + assert values['prompt'] == 'alt>' + assert values['default_socket'] == '/tmp/mysql.sock' + assert values['default_port'] == '3307' + assert values['user'] == '`alice`' + + merged = main.MyCli.merge_ssl_with_cnf(cli, {'mode': 'on'}, {'ssl-ca': '/tmp/ca.pem', 'ssl-verify-server-cert': 'true', 'other': 'x'}) + assert merged['mode'] == 'on' + assert merged['ca'] == '/tmp/ca.pem' + assert merged['check_hostname'] is True + + +def test_connect_covers_defaults_keyring_prompt_retries_and_errors(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {'default_ssl_ca_path': '/ssl/ca-path', 'default_local_infile': 'true'}} + cli.config = {'connection': {'default_ssl_ca_path': '/ssl/ca-path'}, 'main': {'default_character_set': 'utf8mb4'}} + echo_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + cli.echo = lambda *args, **kwargs: echo_calls.append((args, kwargs)) # type: ignore[assignment] + logger = DummyLogger() + cli.logger = cast(Any, logger) + monkeypatch.setattr(main, 'WIN', True) + monkeypatch.setattr(main, 'SQLExecute', RecordingSQLExecute) + RecordingSQLExecute.calls = [] + RecordingSQLExecute.side_effects = [] + monkeypatch.setattr(main, 'guess_socket_location', lambda: '/tmp/mysql.sock') + monkeypatch.setattr(main, 'str_to_bool', lambda value: str(value).lower() == 'true') + monkeypatch.setattr(main.keyring, 'get_password', lambda *args: 'stored-pw') + set_password_calls: list[tuple[str, str, str]] = [] + monkeypatch.setattr(main.keyring, 'set_password', lambda domain, ident, password: set_password_calls.append((domain, ident, password))) + monkeypatch.setenv('USER', 'env-user') + + main.MyCli.connect(cli, host='', port='', ssl={'mode': 'on'}, use_keyring=True) + assert RecordingSQLExecute.calls[-1]['socket'] == '/tmp/mysql.sock' + assert RecordingSQLExecute.calls[-1]['character_set'] == 'utf8mb4' + assert RecordingSQLExecute.calls[-1]['ssl']['capath'] == '/ssl/ca-path' + assert RecordingSQLExecute.calls[-1]['password'] == 'stored-pw' + + prompt_calls: list[str] = [] + + def fake_prompt(message: str, **kwargs: Any) -> str: + prompt_calls.append(message) + return 'entered-pw' + + monkeypatch.setattr(click, 'prompt', fake_prompt) + RecordingSQLExecute.calls = [] + main.MyCli.connect( + cli, user='alice', passwd=main.EMPTY_PASSWORD_FLAG_SENTINEL, host='db', port=3307, ssl={'mode': 'on'}, use_keyring=True + ) + assert prompt_calls == ['Enter password for alice'] + assert set_password_calls[-1][2] == 'entered-pw' + + handshake_error = pymysql.OperationalError(main.HANDSHAKE_ERROR, 'ssl fail') + RecordingSQLExecute.side_effects = [handshake_error, None] + RecordingSQLExecute.calls = [] + main.MyCli.connect(cli, host='db', port=3307, ssl={'mode': 'auto'}) + assert RecordingSQLExecute.calls[0]['ssl']['mode'] == 'auto' + assert RecordingSQLExecute.calls[1]['ssl'] is None + + access_error = pymysql.OperationalError(main.ACCESS_DENIED_ERROR, 'denied') + RecordingSQLExecute.side_effects = [access_error, None] + RecordingSQLExecute.calls = [] + monkeypatch.setattr(click, 'prompt', lambda message, **kwargs: 'retry-pw') + main.MyCli.connect(cli, user='bob', passwd=None, host='db', port=3307) + assert RecordingSQLExecute.calls[1]['password'] == 'retry-pw' + + server_lost = pymysql.OperationalError(main.CR_SERVER_LOST, 'lost') + RecordingSQLExecute.side_effects = [server_lost] + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db', port=3307) + assert any('Connection to server lost' in str(call[0][0]) for call in echo_calls) + + RecordingSQLExecute.side_effects = [] + with pytest.raises(ValueError): + main.MyCli.connect(cli, host='db', port='bad-port') + + +def test_connect_socket_owner_and_tcp_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.config = {'connection': {}, 'main': {}} + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + cli.logger = cast(Any, DummyLogger()) + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'getpwuid', lambda uid: SimpleNamespace(pw_name='socket-owner')) + original_stat = os.stat + + def fake_stat(path: Any, *args: Any, **kwargs: Any) -> os.stat_result: + if str(path) == '/tmp/mysql.sock': + return os.stat_result((0, 0, 0, 0, 123, 0, 0, 0, 0, 0)) + return original_stat(path, *args, **kwargs) + + monkeypatch.setattr(main.os, 'stat', fake_stat) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class SocketThenTcpSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [pymysql.OperationalError(2002, 'socket fail'), None] + + monkeypatch.setattr(main, 'SQLExecute', SocketThenTcpSQLExecute) + main.MyCli.connect(cli, host='', port='', socket='/tmp/mysql.sock', ssl={'mode': 'on'}) + + assert 'Connecting to socket /tmp/mysql.sock, owned by user socket-owner' in echo_calls[0] + assert 'Retrying over TCP/IP' in echo_calls[-1] + assert len(SocketThenTcpSQLExecute.calls) == 2 + + +def test_connect_additional_error_and_config_branches(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.config = {'connection': {'default_ssl_ca_path': '/tmp/ca-path'}, 'main': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + def fake_read_my_cnf(cnf: Any, keys: list[str]) -> dict[str, Any]: + return { + 'database': None, + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'socket': None, + 'default_socket': None, + 'default-character-set': 'latin1', + 'local_infile': None, + 'local-infile': None, + 'loose_local_infile': None, + 'loose-local-infile': None, + 'ssl-ca': None, + 'ssl-cert': None, + 'ssl-key': None, + 'ssl-cipher': None, + 'ssl-verify-server-cert': None, + } + + cli.read_my_cnf = fake_read_my_cnf # type: ignore[assignment] + + class SuccessfulSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + monkeypatch.setattr(main, 'SQLExecute', SuccessfulSQLExecute) + monkeypatch.setattr(main, 'getpwuid', lambda uid: (_ for _ in ()).throw(KeyError())) + original_stat = os.stat + + def fake_stat(path: Any, *args: Any, **kwargs: Any) -> os.stat_result: + if str(path) == '/tmp/mysql.sock': + return os.stat_result((0, 0, 0, 0, 123, 0, 0, 0, 0, 0)) + return original_stat(path, *args, **kwargs) + + monkeypatch.setattr(main.os, 'stat', fake_stat) + main.MyCli.connect(cli, host='', port='', socket='/tmp/mysql.sock', ssl={'mode': 'on'}) + assert 'owned by user ' in echo_calls[0] + assert SuccessfulSQLExecute.calls[-1]['character_set'] == 'latin1' + assert SuccessfulSQLExecute.calls[-1]['ssl']['capath'] == '/tmp/ca-path' + + with pytest.raises(ValueError): + main.MyCli.connect(cli, host='db.example', port='not-a-port') + + class UnexpectedSocketErrorSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [pymysql.OperationalError(9999, 'boom')] + + monkeypatch.setattr(main, 'SQLExecute', UnexpectedSocketErrorSQLExecute) + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='', port='', socket='/tmp/mysql.sock') + + +def test_connect_ssl_overrides_and_retry_password_exhausted(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.config = {'connection': {'default_character_set': 'utf8mb4'}, 'main': {}} + cli.config_without_package_defaults = { + 'connection': { + 'default_local_infile': IntRaises(), + 'default_ssl_ca': '/tmp/ca.pem', + 'default_ssl_cert': '/tmp/cert.pem', + 'default_ssl_key': '/tmp/key.pem', + 'default_ssl_cipher': 'AES256', + 'default_ssl_verify_server_cert': 'true', + } + } + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.logger = cast(Any, DummyLogger()) + cli.echo = lambda *args, **kwargs: None # type: ignore[assignment] + + def fake_read_my_cnf(cnf: Any, keys: list[str]) -> dict[str, Any]: + return { + 'database': None, + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'socket': None, + 'default_socket': None, + 'default-character-set': None, + 'local_infile': None, + 'local-infile': None, + 'loose_local_infile': None, + 'loose-local-infile': None, + 'ssl-ca': None, + 'ssl-cert': None, + 'ssl-key': None, + 'ssl-cipher': None, + 'ssl-verify-server-cert': None, + } + + cli.read_my_cnf = fake_read_my_cnf # type: ignore[assignment] + + def fake_str_to_bool(value: Any) -> bool: + if isinstance(value, IntRaises): + raise ValueError('bad bool') + return str(value).lower() == 'true' + + monkeypatch.setattr(main, 'str_to_bool', fake_str_to_bool) + monkeypatch.setattr(main, 'SQLExecute', RecordingSQLExecute) + RecordingSQLExecute.calls = [] + RecordingSQLExecute.side_effects = [] + main.MyCli.connect(cli, host='db', port=3307, local_infile=cast(Any, IntRaises()), ssl={'mode': 'on'}) + ssl = RecordingSQLExecute.calls[-1]['ssl'] + assert ssl['ca'] == '/tmp/ca.pem' + assert ssl['cert'] == '/tmp/cert.pem' + assert ssl['key'] == '/tmp/key.pem' + assert ssl['cipher'] == 'AES256' + assert ssl['check_hostname'] is True + assert RecordingSQLExecute.calls[-1]['character_set'] == 'utf8mb4' + + access_error = pymysql.OperationalError(main.ACCESS_DENIED_ERROR, 'denied') + RecordingSQLExecute.calls = [] + RecordingSQLExecute.side_effects = [access_error, access_error] + monkeypatch.setattr(click, 'prompt', lambda *args, **kwargs: None) + with pytest.raises(SystemExit): + main.MyCli.connect(cli, user='bob', passwd=None, host='db', port=3307) + + +def test_connect_retries_ssl_password_and_handles_keyring_save_failure(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.config = {'connection': {}, 'main': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.logger = cast(Any, DummyLogger()) + cli.echo = lambda *args, **kwargs: None # type: ignore[assignment] + + def read_my_cnf_all_none(cnf: Any, keys: list[str]) -> dict[str, Any]: + values = dict.fromkeys(keys) + values['local_infile'] = None + values['loose_local_infile'] = None + values['default_character_set'] = None + return values + + cli.read_my_cnf = read_my_cnf_all_none # type: ignore[assignment] + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + + class HandshakeRetrySQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [ + pymysql.OperationalError(main.HANDSHAKE_ERROR, 'ssl fail'), + pymysql.OperationalError(main.HANDSHAKE_ERROR, 'ssl fail'), + ] + + monkeypatch.setattr(main, 'SQLExecute', HandshakeRetrySQLExecute) + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db.example', ssl={'mode': 'auto'}) + assert HandshakeRetrySQLExecute.calls[0]['ssl'] == {'mode': 'auto'} + assert HandshakeRetrySQLExecute.calls[1]['ssl'] is None + + class PasswordRetrySQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [ + pymysql.OperationalError(main.ACCESS_DENIED_ERROR, 'denied'), + pymysql.OperationalError(main.ACCESS_DENIED_ERROR, 'denied'), + ] + + monkeypatch.setattr(main, 'SQLExecute', PasswordRetrySQLExecute) + monkeypatch.setattr(click, 'prompt', lambda *args, **kwargs: 'new-password') + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db.example', passwd=None) + assert PasswordRetrySQLExecute.calls[1]['password'] == 'new-password' + + class KeyringSaveSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + saved_errors: list[str] = [] + monkeypatch.setattr(main, 'SQLExecute', KeyringSaveSQLExecute) + monkeypatch.setattr(main.keyring, 'get_password', lambda domain, ident: 'old-password') + monkeypatch.setattr(main.keyring, 'set_password', lambda domain, ident, password: (_ for _ in ()).throw(RuntimeError('no keyring'))) + monkeypatch.setattr(click, 'secho', lambda message, **kwargs: saved_errors.append(str(message))) + main.MyCli.connect(cli, host='db.example', passwd='new-password', use_keyring=True, reset_keyring=True) + assert any('Password not saved to the system keyring' in message for message in saved_errors) + + +def test_connect_covers_default_ssl_ca_path_and_late_invalid_port(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.config = {'connection': {'default_ssl_ca_path': '/tmp/ca-path'}, 'main': {}} + cli.config_without_package_defaults = {'connection': {}} + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.logger = cast(Any, DummyLogger()) + echo_calls: list[str] = [] + cli.echo = lambda message, **kwargs: echo_calls.append(str(message)) # type: ignore[assignment] + cli.read_my_cnf = lambda cnf, keys: dict.fromkeys(keys) | {'local_infile': None, 'loose_local_infile': None} + monkeypatch.setattr(main, 'WIN', False) + monkeypatch.setattr(main, 'guess_socket_location', lambda: '') + monkeypatch.setattr(main, 'str_to_bool', lambda value: False) + monkeypatch.setattr(main.MyCli, 'merge_ssl_with_cnf', lambda self, ssl, cnf: None) + + class CaptureSQLExecute(RecordingSQLExecute): + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + monkeypatch.setattr(main, 'SQLExecute', CaptureSQLExecute) + main.MyCli.connect(cli, host='', port='', socket='') + assert CaptureSQLExecute.calls[-1]['ssl'] is None + + class PortValue(ToggleBool): + def __init__(self) -> None: + super().__init__([False, False, True]) + + def __int__(self) -> int: + raise ValueError('bad port') + + cli.read_my_cnf = lambda cnf, keys: ( + dict.fromkeys(keys) | {'port': cast(Any, PortValue()), 'local_infile': None, 'loose_local_infile': None} + ) # noqa: C420 + with pytest.raises(SystemExit): + main.MyCli.connect(cli, host='db.example', port='', socket='') + assert any('Invalid port number' in msg for msg in echo_calls) + + +def test_reconnect_logging_and_output(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + cli = make_bare_mycli() + sqlexecute = object.__new__(main.SQLExecute) + + class ThirdPassConnection: + def __init__(self) -> None: + self.select_db_calls: list[str] = [] + + def ping(self, reconnect: bool = False) -> None: + raise pymysql.err.Error() + + def select_db(self, dbname: str) -> None: + self.select_db_calls.append(dbname) + + conn = ThirdPassConnection() + sqlexecute.conn = cast(Any, conn) + sqlexecute.dbname = 'prod' + sqlexecute.connection_id = 10 + + def fake_reset_connection_id() -> None: + return None + + def fake_connect() -> None: + return None + + sqlexecute.reset_connection_id = fake_reset_connection_id # type: ignore[assignment] + sqlexecute.connect = fake_connect # type: ignore[assignment] + cli.sqlexecute = cast(Any, sqlexecute) + echoes: list[str] = [] + cli.echo = lambda message, **kwargs: echoes.append(str(message)) # type: ignore[assignment] + assert main.MyCli.reconnect(cli) is True + assert 'Creating new connection...' in echoes + assert 'Any session state was reset.' in echoes + + def failing_connect() -> None: + raise pymysql.OperationalError(2000, 'still down') + + sqlexecute.connect = failing_connect # type: ignore[assignment] + assert main.MyCli.reconnect(cli) is False + assert 'still down' in echoes[-1] + + logfile = tmp_path / 'audit.log' + with logfile.open('w+', encoding='utf-8') as handle: + cli.logfile = handle + main.MyCli.log_query(cli, 'select 1') + main.MyCli.log_output(cli, ANSI('\x1b[31mhello\x1b[0m')) + handle.seek(0) + contents = handle.read() + assert 'select 1' in contents + assert 'hello' in contents + + printed_status: list[Any] = [] + echoed_lines: list[str] = [] + monkeypatch.setattr(main.special, 'is_redirected', lambda: True) + monkeypatch.setattr(main.special, 'write_tee', lambda text: None) + monkeypatch.setattr(main.special, 'write_once', lambda text: None) + monkeypatch.setattr(main.special, 'write_pipe_once', lambda text: None) + monkeypatch.setattr(main.special, 'is_pager_enabled', lambda: False) + monkeypatch.setattr(main.MyCli, 'get_output_margin', lambda self, status=None: 1) + monkeypatch.setattr(click, 'secho', lambda line, **kwargs: echoed_lines.append(str(line))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append((text, style))) + main.MyCli.output(cli, itertools.chain(['row 1']), SQLResult(status='status')) + assert echoed_lines == [] + assert printed_status + + +def test_reconnect_first_and_second_passes(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + echoes: list[str] = [] + cli.echo = lambda message, **kwargs: echoes.append(str(message)) # type: ignore[assignment] + + class FirstPassConnection: + def ping(self, reconnect: bool = False) -> None: + return None + + sqlexecute = object.__new__(main.SQLExecute) + sqlexecute.conn = cast(Any, FirstPassConnection()) + sqlexecute.dbname = 'db' + sqlexecute.connection_id = 1 + cli.sqlexecute = cast(Any, sqlexecute) + assert main.MyCli.reconnect(cli) is True + assert 'Already connected.' in echoes + + class SecondPassConnection: + def __init__(self) -> None: + self.calls: list[bool] = [] + self.selected: list[str] = [] + + def ping(self, reconnect: bool = False) -> None: + self.calls.append(reconnect) + if not reconnect: + raise pymysql.err.Error() + + def select_db(self, dbname: str) -> None: + self.selected.append(dbname) + + second_conn = SecondPassConnection() + sqlexecute.conn = cast(Any, second_conn) + sqlexecute.connection_id = 10 + + def fake_reset_connection_id() -> None: + sqlexecute.connection_id = 11 + + sqlexecute.reset_connection_id = fake_reset_connection_id # type: ignore[assignment] + assert main.MyCli.reconnect(cli, database='prod') is True + assert second_conn.calls == [False, True] + assert second_conn.selected == ['db'] + assert 'Reconnected successfully.' in echoes + + +def test_format_sqlresult_string_paths_and_close() -> None: + cli = make_bare_mycli() + closed: list[bool] = [] + cli.sqlexecute = cast(Any, SimpleNamespace(close=lambda: closed.append(True))) + main.MyCli.close(cli) + assert closed == [True] + + class StringFormatter(DummyFormatter): + def format_output(self, rows: Any, header: Any, format_name: str | None = None, **kwargs: Any) -> str: + if format_name == 'vertical': + return 'vertical-a\nvertical-b' + return 'short\nsecond' + + cli.main_formatter = StringFormatter() + cli.redirect_formatter = StringFormatter() + result = SQLResult(header=['id'], rows=[(1,)], status='ok') + assert list(main.MyCli.format_sqlresult(cli, result)) == ['short', 'second'] + assert list(main.MyCli.format_sqlresult(cli, result, max_width=10)) == ['short', 'second'] + assert list(main.MyCli.format_sqlresult(cli, result, max_width=2)) == ['vertical-a', 'vertical-b'] + + +def test_output_uses_stdout_and_pager_paths(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.explicit_pager = False + cli.prompt_lines = 1 + cli.prompt_session = None + cli.log_output = lambda text: None # type: ignore[assignment] + monkeypatch.setattr(main.special, 'write_tee', lambda text: None) + monkeypatch.setattr(main.special, 'write_once', lambda text: None) + monkeypatch.setattr(main.special, 'write_pipe_once', lambda text: None) + monkeypatch.setattr(main.special, 'is_redirected', lambda: False) + pager_enabled = {'value': False} + monkeypatch.setattr(main.special, 'is_pager_enabled', lambda: pager_enabled['value']) + monkeypatch.setattr(main.MyCli, 'get_output_margin', lambda self, status=None: 1) + printed_lines: list[str] = [] + paged_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda line, **kwargs: printed_lines.append(str(line))) + monkeypatch.setattr(click, 'echo_via_pager', lambda gen: paged_lines.extend(list(gen))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: None) + + main.MyCli.output(cli, itertools.chain(['a' * 81, 'tail']), SQLResult(status='ok')) + assert printed_lines[:2] == ['a' * 81, 'tail'] + + printed_lines.clear() + pager_enabled['value'] = True + cli.explicit_pager = True + main.MyCli.output(cli, itertools.chain(['row1', 'row2']), SQLResult(status='ok')) + assert paged_lines[-2:] == ['row1\n', 'row2\n'] + + +def test_format_sqlresult_output_covers_extra_branches(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + cli.get_reserved_space = lambda: 1 # type: ignore[assignment] + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) + rows = FakeCursorBase(rows=[], rowcount=0, description=[('id', 3, None, None, None, None, None)]) + result = SQLResult( + header=['id'], + rows=cast(Any, rows), + preamble='preamble', + status=FormattedText([('', 'formatted-status')]), + ) + formatted = list(main.MyCli.format_sqlresult(cli, result, null_string='NULL')) + assert 'preamble' in formatted + _, kwargs = cli.main_formatter.calls[-1] + assert kwargs['missing_value'] == 'NULL' + assert kwargs['column_types'] == [] + assert kwargs['colalign'] == [] + + paged_lines: list[str] = [] + printed_lines: list[str] = [] + status_prints: list[Any] = [] + monkeypatch.setattr(main.special, 'write_tee', lambda text: None) + monkeypatch.setattr(main.special, 'write_once', lambda text: None) + monkeypatch.setattr(main.special, 'write_pipe_once', lambda text: None) + monkeypatch.setattr(main.special, 'is_redirected', lambda: False) + monkeypatch.setattr(main.special, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(main.MyCli, 'get_output_margin', lambda self, status=None: 1) + monkeypatch.setattr(click, 'echo_via_pager', lambda gen: paged_lines.extend(list(gen))) + monkeypatch.setattr(click, 'secho', lambda line, **kwargs: printed_lines.append(str(line))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: status_prints.append(text)) + cli.log_output = lambda text: None # type: ignore[assignment] + cli.explicit_pager = False + main.MyCli.output(cli, itertools.chain(['x' * 81]), result) + assert paged_lines[-1] == ('x' * 81) + '\n' + monkeypatch.setattr(main.special, 'is_pager_enabled', lambda: False) + main.MyCli.output(cli, itertools.chain(['short']), result) + assert printed_lines[-1] == 'short' + assert status_prints + + +def test_main_handles_click_exception_without_exit_code(monkeypatch: pytest.MonkeyPatch) -> None: + class NoExitCode(click.ClickException): + def __getattribute__(self, name: str) -> Any: + if name == 'exit_code': + raise AttributeError(name) + return super().__getattribute__(name) + + monkeypatch.setattr(main, 'filtered_sys_argv', lambda: ['--help']) + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: (_ for _ in ()).throw(NoExitCode('boom'))) + with pytest.raises(SystemExit) as excinfo: + main.main() + assert excinfo.value.code == 2 + + +def test_filtered_sys_argv_covers_help_and_passthrough(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(main.sys, 'argv', ['mycli', '-h']) + assert main.filtered_sys_argv() == ['--help'] + monkeypatch.setattr(main.sys, 'argv', ['mycli', '-h', 'db.example']) + assert main.filtered_sys_argv() == ['-h', 'db.example'] + + +def test_main_wrapper_and_edit_and_execute(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(main, 'filtered_sys_argv', lambda: ['--help']) + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: None) + assert main.main() == 0 + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: 7) + assert main.main() == 7 + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: 'bad') + assert main.main() == 1 + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: (_ for _ in ()).throw(click.Abort())) + with pytest.raises(SystemExit): + main.main() + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: (_ for _ in ()).throw(BrokenPipeError())) + with pytest.raises(SystemExit): + main.main() + + class ErrorWithCode(click.ClickException): + exit_code = 9 + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: (_ for _ in ()).throw(ErrorWithCode('boom'))) + with pytest.raises(SystemExit): + main.main() + + class ErrorNoCode(click.ClickException): + pass + + monkeypatch.setattr(main.click_entrypoint, 'main', lambda *args, **kwargs: (_ for _ in ()).throw(ErrorNoCode('boom'))) + with pytest.raises(SystemExit): + main.main() + + opened: list[bool] = [] + event = cast( + Any, + SimpleNamespace( + current_buffer=SimpleNamespace(open_in_editor=lambda validate_and_handle=False: opened.append(validate_and_handle)) + ), + ) + mycli.key_bindings.edit_and_execute(event) + assert opened == [False] + + +def test_module_main_guard_calls_sys_exit(monkeypatch: pytest.MonkeyPatch) -> None: + exit_codes: list[int | None] = [] + monkeypatch.setattr(sys, 'exit', lambda code=0: exit_codes.append(code)) + monkeypatch.setattr(click.core.Command, 'main', lambda self, *args, **kwargs: 0) + original_main = sys.modules.get('__main__') + spec = importlib.util.spec_from_file_location('__main__', Path(main.__file__)) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules['__main__'] = module + try: + spec.loader.exec_module(module) + finally: + if original_main is not None: + sys.modules['__main__'] = original_main + assert exit_codes[-1] == 0 + + +def test_click_entrypoint_branches_with_dummy_mycli(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + runner = CliRunner() + monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class()) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: True) + + checkup_calls: list[Any] = [] + monkeypatch.setattr(main, 'main_checkup', lambda mycli: checkup_calls.append(mycli)) + result = runner.invoke(main.click_entrypoint, ['--checkup']) + assert result.exit_code == 0 + assert len(checkup_calls) == 1 + + result = runner.invoke(main.click_entrypoint, ['--csv', '--format', 'table']) + assert result.exit_code == 1 + assert 'Conflicting --csv' in result.output + + result = runner.invoke(main.click_entrypoint, ['--table', '--format', 'csv']) + assert result.exit_code == 1 + assert 'Conflicting --table' in result.output + + monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {'a': 'mysql://u:p@h/db'}})) + result = runner.invoke(main.click_entrypoint, ['--list-dsn']) + assert result.exit_code == 0 + assert 'a' in result.output + + monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class(config={'main': {}})) + result = runner.invoke(main.click_entrypoint, ['--list-dsn']) + assert result.exit_code == 1 + assert 'Invalid DSNs found' in result.output + + monkeypatch.setenv('MYSQL_UNIX_PORT', '/tmp/mysql.sock') + monkeypatch.setenv('DSN', 'mysql://user:pw@host/db') + monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class()) + result = runner.invoke(main.click_entrypoint, []) + assert result.exit_code == 0 + assert 'MYSQL_UNIX_PORT environment variable is deprecated' in result.output + assert 'DSN environment variable is deprecated' in result.output + + monkeypatch.delenv('MYSQL_UNIX_PORT', raising=False) + monkeypatch.delenv('DSN', raising=False) + monkeypatch.setattr(main, 'MyCli', make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {}})) + result = runner.invoke(main.click_entrypoint, ['-d', 'missing-dsn']) + assert result.exit_code == 1 + assert 'Could not find the specified DSN' in result.output + + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false'}, + 'alias_dsn': { + 'prod': 'mysql://user:pw@host/db?ssl=true&ssl_ca=/tmp/ca.pem&socket=/tmp/mysql.sock&keepalive_ticks=9&character_set=utf8mb4' + }, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + result = runner.invoke(main.click_entrypoint, ['-d', 'prod', '--ssl-mode', 'off', '--no-ssl']) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + connect_kwargs = dummy.connect_calls[-1] + assert connect_kwargs['database'] == 'db' + assert connect_kwargs['user'] == 'user' + assert connect_kwargs['passwd'] == 'pw' + assert connect_kwargs['socket'] == '/tmp/mysql.sock' + assert connect_kwargs['character_set'] == 'utf8mb4' + assert connect_kwargs['keepalive_ticks'] == 9 + + dummy_class = make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {}}) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: False)) + result = runner.invoke(main.click_entrypoint, ['--execute', 'select 1\\G', '--format', 'csv', '--batch', 'queries.sql']) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.main_formatter.format_name == 'csv' + assert dummy.run_query_calls[-1][0] == 'select 1' + + +def test_click_entrypoint_password_file_and_dsn_early_branches(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + runner = CliRunner() + dummy_class = make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {}, 'connection': {'default_keepalive_ticks': 0}}) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + + missing = runner.invoke(main.click_entrypoint, ['--password-file', str(tmp_path / 'missing.txt')]) + assert missing.exit_code == 1 + assert 'not found' in missing.output + + directory = runner.invoke(main.click_entrypoint, ['--password-file', str(tmp_path)]) + assert directory.exit_code == 1 + assert 'is a directory' in directory.output + + pw_file = tmp_path / 'pw.txt' + pw_file.write_text('from-file\n', encoding='utf-8') + result = runner.invoke(main.click_entrypoint, ['--password-file', str(pw_file)]) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['passwd'] == 'from-file' + + monkeypatch.setenv('MYSQL_PWD', 'envpass') + result = runner.invoke(main.click_entrypoint, []) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['passwd'] == 'envpass' + monkeypatch.delenv('MYSQL_PWD', raising=False) + + monkeypatch.setattr(main, 'is_valid_connection_scheme', lambda text: (False, 'bogus')) + result = runner.invoke(main.click_entrypoint, ['--password', 'bogus://dsn']) + assert result.exit_code == 1 + assert 'Unknown connection scheme' in result.output + + monkeypatch.setattr(main, 'is_valid_connection_scheme', lambda text: (True, 'mysql')) + result = runner.invoke(main.click_entrypoint, ['--password', 'mysql://dsn_user:dsn_pass@dsn_host/dsn_db']) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['database'] == 'dsn_db' + + +def test_click_entrypoint_list_and_dsn_option_branches(monkeypatch: pytest.MonkeyPatch) -> None: + runner = CliRunner() + + class ErrorConfig(dict[str, Any]): + def __getitem__(self, key: str) -> Any: + if key == 'alias_dsn': + raise RuntimeError('bad aliases') + return super().__getitem__(key) + + dummy_class = make_dummy_mycli_class(config=cast(Any, ErrorConfig({'main': {}}))) + monkeypatch.setattr(main, 'MyCli', dummy_class) + result = runner.invoke(main.click_entrypoint, ['--list-dsn']) + assert result.exit_code == 1 + assert 'bad aliases' in result.output + + dummy_class = make_dummy_mycli_class( + config={'main': {}, 'alias_dsn': {'prod': 'mysql://u:p@h/db'}, 'connection': {'default_keepalive_ticks': 0}} + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + result = runner.invoke(main.click_entrypoint, ['prod']) + assert result.exit_code == 0 + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.init_kwargs['myclirc'] == '~/.myclirc' + assert dummy.dsn_alias == 'prod' + + result = runner.invoke(main.click_entrypoint, ['mysql://u:p@h/db']) + assert result.exit_code == 0 + + result = runner.invoke(main.click_entrypoint, ['--dsn', 'mysql://u:p@h/db']) + assert result.exit_code == 0 + + +def test_click_entrypoint_callback_covers_password_file_permission_and_generic_errors(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {}, 'connection': {'default_keepalive_ticks': 0}}) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + cli_args = main.CliArgs() + cli_args.password_file = '/tmp/secret' + + monkeypatch.setattr(builtins, 'open', lambda *args, **kwargs: (_ for _ in ()).throw(PermissionError())) + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + + monkeypatch.setattr(builtins, 'open', lambda *args, **kwargs: (_ for _ in ()).throw(RuntimeError('boom'))) + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + + +def test_click_entrypoint_callback_covers_nested_empty_password_file_guard(monkeypatch: pytest.MonkeyPatch) -> None: + class TogglePasswordFile: + def __init__(self) -> None: + self.calls = 0 + + def __bool__(self) -> bool: + self.calls += 1 + return self.calls == 1 + + dummy_class = make_dummy_mycli_class(config={'main': {}, 'alias_dsn': {}, 'connection': {'default_keepalive_ticks': 0}}) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + open_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def fake_open(*args: Any, **kwargs: Any) -> None: + open_calls.append((args, kwargs)) + return None + + monkeypatch.setattr(builtins, 'open', fake_open) + cli_args = main.CliArgs() + cli_args.password_file = cast(Any, TogglePasswordFile()) + call_click_entrypoint_direct(cli_args) + + dummy = dummy_class.last_instance + assert dummy is not None + assert dummy.connect_calls[-1]['passwd'] is None + assert open_calls == [] + + +def test_click_entrypoint_callback_covers_dsn_params_init_commands_and_keyring(monkeypatch: pytest.MonkeyPatch) -> None: + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 2}, + 'alias_dsn': { + 'prod': ( + 'mysql://user:pw@db.example/prod_db' + '?ssl_mode=auto&ssl_ca=/tmp/ca.pem&ssl_capath=/tmp/capath' + '&ssl_cert=/tmp/cert.pem&ssl_key=/tmp/key.pem&ssl_cipher=AES256' + '&tls_version=TLSv1.2&ssl_verify_server_cert=true&socket=/tmp/mysql.sock' + '&keepalive_ticks=9&character_set=utf8mb4' + ) + }, + 'init-commands': {'a': 'set a=1', 'b': ['set b=2']}, + 'alias_dsn.init-commands': {'prod': 'set c=3'}, + }, + my_cnf={'client': {}, 'mysqld': {}}, + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + click_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) + monkeypatch.setattr(click, 'echo', lambda message='', **kwargs: click_lines.append(str(message))) + + cli_args = main.CliArgs() + cli_args.database = 'prod' + cli_args.init_command = 'set e=5' + cli_args.use_keyring = 'reset' + call_click_entrypoint_direct(cli_args) + + dummy = dummy_class.last_instance + assert dummy is not None + connect_kwargs = dummy.connect_calls[-1] + assert connect_kwargs['database'] == 'prod_db' + assert connect_kwargs['user'] == 'user' + assert connect_kwargs['passwd'] == 'pw' + assert connect_kwargs['ssl'] is None + assert connect_kwargs['character_set'] == 'utf8mb4' + assert connect_kwargs['keepalive_ticks'] == 9 + assert connect_kwargs['use_keyring'] is True + assert connect_kwargs['reset_keyring'] is True + assert connect_kwargs['init_command'] == 'set a=1; set b=2; set c=3; set e=5' + assert any('Executing init-command:' in line for line in click_lines) + + +def test_click_entrypoint_callback_covers_database_dsn_and_verbose_lists(monkeypatch: pytest.MonkeyPatch) -> None: + click_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + cli_args = main.CliArgs() + cli_args.database = ( + 'mysql://dsn_user:dsn_pass@dsn_host/dsn_db' + '?ssl_capath=/tmp/capath&ssl_cert=/tmp/cert.pem&ssl_key=/tmp/key.pem' + '&ssl_cipher=AES256&tls_version=TLSv1.2&ssl_verify_server_cert=true' + ) + cli_args.use_keyring = 'false' + call_click_entrypoint_direct(cli_args) + dummy = dummy_class.last_instance + assert dummy is not None + connect_kwargs = dummy.connect_calls[-1] + assert connect_kwargs['database'] == 'dsn_db' + assert connect_kwargs['user'] == 'dsn_user' + assert connect_kwargs['passwd'] == 'dsn_pass' + assert connect_kwargs['host'] == 'dsn_host' + assert connect_kwargs['ssl']['capath'] == '/tmp/capath' + assert connect_kwargs['ssl']['cert'] == '/tmp/cert.pem' + assert connect_kwargs['ssl']['key'] == '/tmp/key.pem' + assert connect_kwargs['ssl']['cipher'] == 'AES256' + assert connect_kwargs['ssl']['tls_version'] == 'TLSv1.2' + assert connect_kwargs['ssl']['check_hostname'] is True + assert connect_kwargs['use_keyring'] is False + + +def test_click_entrypoint_callback_covers_misc_format_transition_and_execute_branches( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + click_lines: list[str] = [] + monkeypatch.setattr(click, 'secho', lambda message='', **kwargs: click_lines.append(str(message))) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: True)) + monkeypatch.setattr(main.sys.stderr, 'isatty', lambda: False) + + dummy_class = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'false'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + }, + my_cnf={'client': {'prompt': 'mysql>'}, 'mysqld': {}}, + config_without_package_defaults={'main': {}}, + ) + monkeypatch.setattr(main, 'MyCli', dummy_class) + + pw_file = tmp_path / 'pw.txt' + pw_file.write_text('from-file\n', encoding='utf-8') + cli_args = main.CliArgs() + cli_args.password_file = str(pw_file) + call_click_entrypoint_direct(cli_args) + assert dummy_class.last_instance is not None + assert dummy_class.last_instance.connect_calls[-1]['passwd'] == 'from-file' + + cli_args = main.CliArgs() + cli_args.csv = True + call_click_entrypoint_direct(cli_args) + assert cli_args.format == 'csv' + + cli_args = main.CliArgs() + cli_args.table = True + call_click_entrypoint_direct(cli_args) + assert cli_args.format == 'table' + + assert any('Reading configuration from my.cnf files is deprecated.' in line for line in click_lines) + + execute_dummy_cls: type[Any] = make_dummy_mycli_class( + config={ + 'main': {'use_keyring': 'false', 'my_cnf_transition_done': 'true'}, + 'connection': {'default_keepalive_ticks': 0}, + 'alias_dsn': {}, + } + ) + monkeypatch.setattr(main, 'MyCli', execute_dummy_cls) + monkeypatch.setattr(main.sys, 'stdin', SimpleNamespace(isatty=lambda: False)) + + cli_args = main.CliArgs() + cli_args.execute = 'select 1\\G' + cli_args.format = 'tsv' + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + assert execute_dummy_cls.last_instance.main_formatter.format_name == 'tsv' + assert execute_dummy_cls.last_instance.run_query_calls[-1][0] == 'select 1' + + cli_args = main.CliArgs() + cli_args.execute = 'select 2\\G' + cli_args.format = 'table' + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + assert execute_dummy_cls.last_instance.main_formatter.format_name == 'ascii' + assert execute_dummy_cls.last_instance.run_query_calls[-1][0] == 'select 2' + + cli_args = main.CliArgs() + cli_args.execute = 'select 3' + cli_args.format = None + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + assert execute_dummy_cls.last_instance.main_formatter.format_name == 'tsv' + + def failing_run_query(self: Any, query: str, checkpoint: Any = None, new_line: bool = True) -> None: + raise RuntimeError('execute failed') + + FailingExecuteMyCli = cast(Any, type('FailingExecuteMyCli', (execute_dummy_cls,), {'run_query': failing_run_query})) + monkeypatch.setattr(main, 'MyCli', FailingExecuteMyCli) + cli_args = main.CliArgs() + cli_args.execute = 'select 4' + with pytest.raises(SystemExit): + call_click_entrypoint_direct(cli_args) + assert any('execute failed' in line for line in click_lines) + + +def test_configure_pager_and_refresh_completions(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = {'client': {}, 'mysqld': {}} + cli.config = {'main': BoolSection({'pager': '', 'enable_pager': 'true'})} + cli.read_my_cnf = lambda cnf, keys: {'pager': 'less', 'skip-pager': ''} # type: ignore[assignment] + set_pager_calls: list[str] = [] + disable_calls: list[bool] = [] + monkeypatch.delenv('LESS', raising=False) + monkeypatch.setattr(main.special, 'set_pager', lambda pager: set_pager_calls.append(pager)) + monkeypatch.setattr(main.special, 'disable_pager', lambda: disable_calls.append(True)) + monkeypatch.setattr(output_module, 'WIN', True) + monkeypatch.setattr(shutil, 'which', lambda name: None) + main.MyCli.configure_pager(cli) + assert os.environ['LESS'] == '-RXF' + assert set_pager_calls == ['more'] + assert cli.explicit_pager is True + + class DisablePagerCalled(Exception): + pass + + def fake_disable_pager() -> None: + disable_calls.append(True) + assert cli.explicit_pager is False + raise DisablePagerCalled + + monkeypatch.setattr(main.special, 'disable_pager', fake_disable_pager) + cli.read_my_cnf = lambda cnf, keys: {'pager': '', 'skip-pager': '1'} # type: ignore[assignment] + with pytest.raises(DisablePagerCalled): + main.MyCli.configure_pager(cli) + + set_dbname_calls: list[str | None] = [] + refresh_calls: list[tuple[Any, Any, dict[str, Any]]] = [] + cli.completer = cast( + Any, + SimpleNamespace( + keyword_casing='upper', + set_dbname=lambda name: set_dbname_calls.append(name), + ), + ) + cli.main_formatter = SimpleNamespace(supported_formats=['ascii', 'csv']) + cli.completion_refresher = SimpleNamespace(refresh=lambda sql, callback, options: refresh_calls.append((sql, callback, options))) + cli.sqlexecute = SimpleNamespace(dbname='current_db') + cli._on_completions_refreshed = lambda new_completer: None # type: ignore[assignment] + + def fake_refresh(reset: bool = False) -> list[SQLResult]: + return main.MyCli.refresh_completions(cli, reset=reset) + + result = fake_refresh(reset=True) + assert set_dbname_calls == ['current_db'] + assert refresh_calls[0][2] == { + 'smart_completion': cli.smart_completion, + 'supported_formats': ['ascii', 'csv'], + 'keyword_casing': 'upper', + } + assert result[0].status == 'Auto-completion refresh started in the background.' diff --git a/test/pytests/test_naive_completion.py b/test/pytests/test_naive_completion.py new file mode 100644 index 00000000..fb7556d7 --- /dev/null +++ b/test/pytests/test_naive_completion.py @@ -0,0 +1,104 @@ +# type: ignore + +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +import pytest + +from test.utils import pygments_below + + +@pytest.fixture +def completer(): + import mycli.sqlcompleter as sqlcompleter + + return sqlcompleter.SQLCompleter(smart_completion=False) + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + + return Mock() + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list(map(Completion, completer.all_completions)) + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [Completion(text="SELECT", start_position=-3)] + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + expected = [ + 'MAKEDATE', + 'MAKETIME', + 'MAKE_SET', + 'MANUAL', + 'MASTER', + 'MASTER_AUTO_POSITION', + 'MASTER_BIND', + 'MASTER_COMPRESSION_ALGORITHMS', + 'MASTER_CONNECT_RETRY', + 'MASTER_DELAY', + 'MASTER_HEARTBEAT_PERIOD', + 'MASTER_HOST', + 'MASTER_LOG_FILE', + 'MASTER_LOG_POS', + 'MASTER_PASSWORD', + 'MASTER_PORT', + 'MASTER_POS_WAIT', + 'MASTER_PUBLIC_KEY_PATH', + 'MASTER_RETRY_COUNT', + 'MASTER_SSL', + 'MASTER_SSL_CA', + 'MASTER_SSL_CAPATH', + 'MASTER_SSL_CERT', + 'MASTER_SSL_CIPHER', + 'MASTER_SSL_CRL', + 'MASTER_SSL_CRLPATH', + 'MASTER_SSL_KEY', + 'MASTER_SSL_VERIFY_SERVER_CERT', + 'MASTER_TLS_CIPHERSUITES', + 'MASTER_TLS_VERSION', + 'MASTER_USER', + 'MASTER_ZSTD_COMPRESSION_LEVEL', + 'MATCH', + 'MAX', + 'MAXVALUE', + 'MAX_CONNECTIONS_PER_HOUR', + 'MAX_QUERIES_PER_HOUR', + 'MAX_ROWS', + 'MAX_SIZE', + 'MAX_UPDATES_PER_HOUR', + 'MAX_USER_CONNECTIONS', + ] + + if pygments_below("2.20"): + expected.remove('MANUAL') + + assert sorted(x.text for x in result) == sorted(expected) + + +def test_column_name_completion(completer, complete_event): + text = "SELECT FROM users" + position = len("SELECT ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list(map(Completion, completer.all_completions)) + + +def test_special_name_completion(completer, complete_event): + text = "\\" + position = len("\\") + result = set(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + # Special commands will NOT be suggested during naive completion mode. + assert result == set() diff --git a/test/pytests/test_output.py b/test/pytests/test_output.py new file mode 100644 index 00000000..47f7e0f5 --- /dev/null +++ b/test/pytests/test_output.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import itertools +import shutil +from typing import Any, cast + +import click +from configobj import ConfigObj +import prompt_toolkit +from prompt_toolkit.formatted_text import ANSI, FormattedText, to_plain_text +import pytest + +from mycli import output as output_module +from mycli.output import OutputMixin +from mycli.packages.sqlresult import SQLResult +from test.utils import DummyFormatter, FakeCursorBase, make_bare_mycli # type: ignore[attr-defined] + + +def test_output_timing_logs_and_prints_with_default_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + logged: list[Any] = [] + printed: list[tuple[Any, Any]] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append((text, style))) + + OutputMixin.output_timing(cli, '0.12 sec') + + assert logged == ['0.12 sec'] + assert to_plain_text(printed[0][0]) == '0.12 sec' + assert list(printed[0][0])[0][0].strip() == 'class:output.timing' + assert printed[0][1] == cli.ptoolkit_style + + +def test_output_timing_uses_warning_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.log_output = lambda value: None # type: ignore[assignment] + printed: list[Any] = [] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed.append(text)) + + OutputMixin.output_timing(cli, '0.34 sec', is_warnings_style=True) + + assert list(printed[0])[0][0].strip() == 'class:warnings.timing' + + +def test_log_query_and_log_output_write_plain_text(tmp_path) -> None: + cli = make_bare_mycli() + logfile = tmp_path / 'audit.log' + + with logfile.open('w+', encoding='utf-8') as handle: + cli.logfile = handle + OutputMixin.log_query(cli, 'select 1') + OutputMixin.log_output(cli, ANSI('\x1b[31mhello\x1b[0m')) + handle.seek(0) + contents = handle.read() + + assert 'select 1' in contents + assert 'hello' in contents + assert '\x1b[31m' not in contents + + +def test_log_output_ignores_missing_logfile() -> None: + cli = make_bare_mycli() + cli.logfile = None + + OutputMixin.log_output(cli, 'nothing to write') + + +def test_echo_logs_and_prints(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + logged: list[str] = [] + printed: list[tuple[str, dict[str, Any]]] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(click, 'secho', lambda value, **kwargs: printed.append((value, kwargs))) + + OutputMixin.echo(cli, 'message', fg='red') + + assert logged == ['message'] + assert printed == [('message', {'fg': 'red'})] + + +def test_get_output_margin_renders_prompt_once_and_counts_status_lines(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_lines = 0 + cli.prompt_format = 'ignored' + cli.prompt_session = None + cli.get_reserved_space = lambda: 2 # type: ignore[assignment] + monkeypatch.setattr(output_module.repl_mode, 'render_prompt_string', lambda *_args: FormattedText([('', 'one\ntwo')])) + monkeypatch.setattr(output_module.special, 'is_timing_enabled', lambda: True) + + margin = OutputMixin.get_output_margin(cli, 'ok\nwarning') + + assert margin == 7 + assert cli.prompt_lines == 2 + + +def test_output_writes_lines_sinks_and_status(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_session = None + cli.explicit_pager = False + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + logged: list[Any] = [] + tee: list[str] = [] + once: list[str] = [] + pipe_once: list[str] = [] + printed_lines: list[str] = [] + printed_status: list[Any] = [] + cli.log_output = lambda value: logged.append(value) # type: ignore[assignment] + monkeypatch.setattr(output_module.special, 'write_tee', lambda value: tee.append(value)) + monkeypatch.setattr(output_module.special, 'write_once', lambda value: once.append(value)) + monkeypatch.setattr(output_module.special, 'write_pipe_once', lambda value: pipe_once.append(value)) + monkeypatch.setattr(output_module.special, 'is_redirected', lambda: False) + monkeypatch.setattr(output_module.special, 'is_pager_enabled', lambda: False) + monkeypatch.setattr(click, 'secho', lambda value, **_kwargs: printed_lines.append(value)) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append(text)) + + OutputMixin.output(cli, itertools.chain(['row 1', 'row 2']), SQLResult(status='done')) + + assert logged == ['row 1', 'row 2', 'done'] + assert tee == ['row 1', 'row 2'] + assert once == ['row 1', 'row 2'] + assert pipe_once == ['row 1', 'row 2'] + assert printed_lines == ['row 1', 'row 2'] + assert to_plain_text(printed_status[0]) == 'done' + assert list(printed_status[0])[0][0].strip() == 'class:output.status' + + +def test_output_uses_warning_status_style(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.log_output = lambda value: None # type: ignore[assignment] + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + printed_status: list[Any] = [] + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: printed_status.append(text)) + + OutputMixin.output(cli, itertools.chain([]), SQLResult(status='warning'), is_warnings_style=True) + + assert list(printed_status[0])[0][0].strip() == 'class:warnings.status' + + +def test_output_sends_buffer_to_pager_when_pager_is_explicit(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.prompt_session = None + cli.explicit_pager = True + cli.log_output = lambda value: None # type: ignore[assignment] + cli.get_output_margin = lambda status=None: 1 # type: ignore[assignment] + paged_lines: list[str] = [] + monkeypatch.setattr(output_module.special, 'write_tee', lambda value: None) + monkeypatch.setattr(output_module.special, 'write_once', lambda value: None) + monkeypatch.setattr(output_module.special, 'write_pipe_once', lambda value: None) + monkeypatch.setattr(output_module.special, 'is_redirected', lambda: False) + monkeypatch.setattr(output_module.special, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(click, 'echo_via_pager', lambda values: paged_lines.extend(list(values))) + monkeypatch.setattr(prompt_toolkit, 'print_formatted_text', lambda text, style=None: None) + + OutputMixin.output(cli, itertools.chain(['row 1', 'row 2']), SQLResult()) + + assert paged_lines == ['row 1\n', 'row 2\n'] + + +def test_configure_pager_prefers_my_cnf_pager_and_sets_less(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = ConfigObj({'client': {'pager': 'my-pager'}}) + cli.config = ConfigObj({'main': {'pager': 'config-pager', 'enable_pager': 'True'}}) + cli.read_my_cnf = lambda cnf, keys: {'pager': 'my-pager', 'skip-pager': None} # type: ignore[assignment] + pager_calls: list[str] = [] + disabled: list[bool] = [] + monkeypatch.delenv('LESS', raising=False) + monkeypatch.setattr(output_module.special, 'set_pager', lambda value: pager_calls.append(value)) + monkeypatch.setattr(output_module.special, 'disable_pager', lambda: disabled.append(True)) + + OutputMixin.configure_pager(cli) + + assert pager_calls == ['my-pager'] + assert disabled == [] + assert cli.explicit_pager is True + assert output_module.os.environ['LESS'] == '-RXF' + + +def test_configure_pager_disables_when_skip_pager_is_set(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.my_cnf = ConfigObj({'client': {}}) + cli.config = ConfigObj({'main': {'pager': '', 'enable_pager': 'True'}}) + cli.read_my_cnf = lambda cnf, keys: {'pager': None, 'skip-pager': '1'} # type: ignore[assignment] + disabled: list[bool] = [] + monkeypatch.setattr(output_module.special, 'set_pager', lambda value: None) + monkeypatch.setattr(output_module.special, 'disable_pager', lambda: disabled.append(True)) + + OutputMixin.configure_pager(cli) + + assert cli.explicit_pager is False + assert disabled == [True] + + +def test_format_sqlresult_uses_redirect_formatter_and_appends_preamble_postamble() -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + result = SQLResult(preamble='before', header=['id'], rows=[(1,)], postamble='after') + + formatted = list(OutputMixin.format_sqlresult(cli, result, is_redirected=True)) + + assert formatted == ['before', 'plain output', 'after'] + assert cli.main_formatter.calls == [] + assert cli.redirect_formatter.calls + + +def test_format_sqlresult_for_cursor_sets_column_types_and_alignment(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + monkeypatch.setattr(output_module, 'Cursor', FakeCursorBase) + rows = FakeCursorBase(rows=[(1, 'name')], rowcount=1, description=[('id', 3), ('name', 253)]) + result = SQLResult(header=['id', 'name'], rows=cast(Any, rows)) + + assert list(OutputMixin.format_sqlresult(cli, result, numeric_alignment='left')) == ['plain output'] + + _, kwargs = cli.main_formatter.calls[-1] + assert kwargs['column_types'] == [int, str] + assert kwargs['colalign'] == ['left', 'left'] + + +def test_format_sqlresult_switches_to_vertical_when_first_line_is_too_wide() -> None: + cli = make_bare_mycli() + cli.main_formatter = DummyFormatter() + result = SQLResult(header=['id'], rows=[(1,)]) + + assert list(OutputMixin.format_sqlresult(cli, result, max_width=2)) == ['vertical output'] + + +def test_get_reserved_space_caps_ratio(monkeypatch: pytest.MonkeyPatch) -> None: + cli = make_bare_mycli() + monkeypatch.setattr(shutil, 'get_terminal_size', lambda *args, **kwargs: (120, 40)) + + assert OutputMixin.get_reserved_space(cli) == 8 diff --git a/test/pytests/test_ptoolkit_fzf.py b/test/pytests/test_ptoolkit_fzf.py new file mode 100644 index 00000000..2bbaa001 --- /dev/null +++ b/test/pytests/test_ptoolkit_fzf.py @@ -0,0 +1,192 @@ +from types import SimpleNamespace +from typing import Any, cast + +import pytest + +from mycli.packages.ptoolkit import fzf as fzf_module +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp + + +class DummyHistory(FileHistoryWithTimestamp): + def __init__(self, items: list[tuple[str, str]]) -> None: + self._items = items + + def load_history_with_timestamp(self) -> list[tuple[str, str]]: + return self._items + + +def make_event(history: Any) -> SimpleNamespace: + buffer = SimpleNamespace(history=history, text='original', cursor_position=0) + return SimpleNamespace( + current_buffer=buffer, + app=SimpleNamespace(), + ) + + +def test_fzf_init_and_is_available(monkeypatch) -> None: + init_calls: list[bool] = [] + + monkeypatch.setattr(fzf_module, 'which', lambda executable: '/usr/bin/fzf' if executable == 'fzf' else None) + monkeypatch.setattr(fzf_module.FzfPrompt, '__init__', lambda self: init_calls.append(True)) + + fzf = fzf_module.Fzf() + + assert fzf.executable == '/usr/bin/fzf' + assert fzf.is_available() is True + assert init_calls == [True] + + +def test_fzf_init_without_executable_skips_super(monkeypatch) -> None: + init_calls: list[bool] = [] + + monkeypatch.setattr(fzf_module, 'which', lambda executable: None) + monkeypatch.setattr(fzf_module.FzfPrompt, '__init__', lambda self: init_calls.append(True)) + + fzf = fzf_module.Fzf() + + assert fzf.executable is None + assert fzf.is_available() is False + assert init_calls == [] + + +def test_search_history_falls_back_to_prompt_toolkit_search(monkeypatch) -> None: + calls: list[dict[str, Any]] = [] + event = make_event(history=object()) + + monkeypatch.setattr( + fzf_module.search, + 'start_search', + lambda **kwargs: calls.append(kwargs), + ) + + fzf_module.search_history(cast(Any, event), incremental=True) + + assert calls == [{'direction': fzf_module.search.SearchDirection.BACKWARD}] + + +def test_search_history_falls_back_when_fzf_unavailable_or_history_type_is_wrong(monkeypatch) -> None: + calls: list[dict[str, Any]] = [] + unavailable_event = make_event(history=DummyHistory([])) + wrong_history_event = make_event(history=[]) + + class UnavailableFzf: + def is_available(self) -> bool: + return False + + monkeypatch.setattr( + fzf_module.search, + 'start_search', + lambda **kwargs: calls.append(kwargs), + ) + + monkeypatch.setattr(fzf_module, 'Fzf', UnavailableFzf) + fzf_module.search_history(cast(Any, unavailable_event)) + + class AvailableFzf: + def is_available(self) -> bool: + return True + + monkeypatch.setattr(fzf_module, 'Fzf', AvailableFzf) + fzf_module.search_history(cast(Any, wrong_history_event)) + + assert calls == [ + {'direction': fzf_module.search.SearchDirection.BACKWARD}, + {'direction': fzf_module.search.SearchDirection.BACKWARD}, + ] + + +def test_search_history_formats_preview_updates_buffer_and_deduplicates(monkeypatch) -> None: + prompt_calls: list[dict[str, Any]] = [] + invalidated_apps: list[Any] = [] + + history = DummyHistory([ + ('SELECT 1\nFROM dual', '2026-01-02 03:04:05.678'), + ('SELECT 1 FROM dual', '2026-01-01 00:00:00'), + ('SELECT 2', '2026-01-03 12:00:00'), + ]) + event = make_event(history=history) + + class PromptingFzf: + def is_available(self) -> bool: + return True + + def prompt(self, items: list[str], fzf_options: str) -> list[str]: + prompt_calls.append({'items': items, 'options': fzf_options}) + return [items[0]] + + monkeypatch.setattr(fzf_module, 'Fzf', PromptingFzf) + monkeypatch.setattr( + fzf_module, + 'which', + lambda executable: '/usr/bin/pygmentize' if executable == 'pygmentize' else None, + ) + monkeypatch.setattr(fzf_module, 'safe_invalidate_display', lambda app: invalidated_apps.append(app)) + + fzf_module.search_history( + cast(Any, event), + highlight_preview=True, + highlight_style='monokai style', + ) + + assert prompt_calls == [ + { + 'items': [ + '2026-01-02 03:04:05 SELECT 1 FROM dual', + '2026-01-03 12:00:00 SELECT 2', + ], + 'options': '--info=hidden --scheme=history --tiebreak=index --bind=ctrl-r:up,alt-r:up ' + '--preview-window=down:wrap:nohidden --no-height ' + "--preview=\"printf '%s' {} | pygmentize -l mysql -P style='monokai style'\"", + } + ] + assert invalidated_apps == [event.app] + assert event.current_buffer.text == 'SELECT 1\nFROM dual' + assert event.current_buffer.cursor_position == len('SELECT 1\nFROM dual') + + +@pytest.mark.parametrize( + ("highlight_preview", "pygmentize_available"), + [ + (False, False), + (False, True), + (True, False), + ], +) +def test_search_history_without_result_keeps_buffer_and_uses_plain_preview( + monkeypatch, + highlight_preview: bool, + pygmentize_available: bool, +) -> None: + prompt_calls: list[dict[str, Any]] = [] + invalidated_apps: list[Any] = [] + + event = make_event(history=DummyHistory([('SELECT 1', '2026-01-01 00:00:00')])) + + class PromptingFzf: + def is_available(self) -> bool: + return True + + def prompt(self, items: list[str], fzf_options: str) -> list[str]: + prompt_calls.append({'items': items, 'options': fzf_options}) + return [] + + monkeypatch.setattr(fzf_module, 'Fzf', PromptingFzf) + monkeypatch.setattr( + fzf_module, + 'which', + lambda executable: '/usr/bin/pygmentize' if pygmentize_available and executable == 'pygmentize' else None, + ) + monkeypatch.setattr(fzf_module, 'safe_invalidate_display', lambda app: invalidated_apps.append(app)) + + fzf_module.search_history(cast(Any, event), highlight_preview=highlight_preview) + + assert prompt_calls == [ + { + 'items': ['2026-01-01 00:00:00 SELECT 1'], + 'options': '--info=hidden --scheme=history --tiebreak=index --bind=ctrl-r:up,alt-r:up ' + "--preview-window=down:wrap:nohidden --no-height --preview=\"printf '%s' {}\"", + } + ] + assert invalidated_apps == [event.app] + assert event.current_buffer.text == 'original' + assert event.current_buffer.cursor_position == 0 diff --git a/test/pytests/test_ptoolkit_history.py b/test/pytests/test_ptoolkit_history.py new file mode 100644 index 00000000..ce54b590 --- /dev/null +++ b/test/pytests/test_ptoolkit_history.py @@ -0,0 +1,72 @@ +# type: ignore + +from pathlib import Path + +from mycli.packages.ptoolkit import history as history_module +from mycli.packages.ptoolkit.history import FileHistoryWithTimestamp + + +def test_file_history_with_timestamp_sets_filename(tmp_path: Path) -> None: + history_path = tmp_path / 'history.txt' + + history = FileHistoryWithTimestamp(history_path) + + assert history.filename == history_path + + +def test_append_string_caches_and_stores_non_password_statement(tmp_path: Path, monkeypatch) -> None: + history = FileHistoryWithTimestamp(tmp_path / 'history.txt') + stored: list[str] = [] + monkeypatch.setattr(history, 'store_string', stored.append) + + history.append_string('SELECT 1') + + assert history.get_strings()[0] == 'SELECT 1' + assert stored == ['SELECT 1'] + + +def test_append_string_does_not_store_password_change(tmp_path: Path, monkeypatch) -> None: + history = FileHistoryWithTimestamp(tmp_path / 'history.txt') + stored: list[str] = [] + monkeypatch.setattr(history, 'store_string', stored.append) + monkeypatch.setattr(history_module, 'is_password_change', lambda string: True) + + history.append_string("SET PASSWORD = 'secret'") + + assert history.get_strings()[0] == "SET PASSWORD = 'secret'" + assert stored == [] + + +def test_load_history_with_timestamp_returns_empty_when_file_is_missing(tmp_path: Path) -> None: + history = FileHistoryWithTimestamp(tmp_path / 'missing-history.txt') + + assert history.load_history_with_timestamp() == [] + + +def test_load_history_with_timestamp_parses_and_reverses_entries(tmp_path: Path) -> None: + history_path = tmp_path / 'history.txt' + history_path.write_text( + '# 2026-04-02 10:00:00\n+SELECT 1\n+FROM dual\n\n# 2026-04-02 11:00:00\n+SHOW DATABASES\n', + encoding='utf-8', + ) + + history = FileHistoryWithTimestamp(history_path) + + assert history.load_history_with_timestamp() == [ + ('SHOW DATABASES', '2026-04-02 11:00:00'), + ('SELECT 1\nFROM dual', '2026-04-02 10:00:00'), + ] + + +def test_load_history_with_timestamp_ignores_empty_separator_blocks(tmp_path: Path) -> None: + history_path = tmp_path / 'history.txt' + history_path.write_text( + '# 2026-04-02 10:00:00\n\n# 2026-04-02 11:00:00\n+SELECT 1\n\ngarbage separator\n', + encoding='utf-8', + ) + + history = FileHistoryWithTimestamp(history_path) + + assert history.load_history_with_timestamp() == [ + ('SELECT 1', '2026-04-02 11:00:00'), + ] diff --git a/test/pytests/test_ptoolkit_utils.py b/test/pytests/test_ptoolkit_utils.py new file mode 100644 index 00000000..cd3773d1 --- /dev/null +++ b/test/pytests/test_ptoolkit_utils.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass, field +from typing import Any, cast + +from mycli.packages.ptoolkit import utils as ptoolkit_utils + + +@dataclass +class DummyApp: + print_calls: list[str] = field(default_factory=list) + + def print_text(self, text: str) -> None: + self.print_calls.append(text) + + +def test_safe_invalidate_display_runs_empty_terminal_print(monkeypatch) -> None: + app = DummyApp() + callbacks: list[object] = [] + + def fake_run_in_terminal(callback) -> None: + callbacks.append(callback) + callback() + + monkeypatch.setattr(ptoolkit_utils, 'run_in_terminal', fake_run_in_terminal) + + ptoolkit_utils.safe_invalidate_display(cast(Any, app)) + + assert len(callbacks) == 1 + assert app.print_calls == [''] + + +def test_safe_invalidate_display_swallows_runtime_error(monkeypatch) -> None: + app = DummyApp() + + def fail_run_in_terminal(_callback) -> None: + raise RuntimeError('application is exiting') + + monkeypatch.setattr(ptoolkit_utils, 'run_in_terminal', fail_run_in_terminal) + + ptoolkit_utils.safe_invalidate_display(cast(Any, app)) + + assert app.print_calls == [] diff --git a/test/pytests/test_schema_prefetcher.py b/test/pytests/test_schema_prefetcher.py new file mode 100644 index 00000000..0eebe8b6 --- /dev/null +++ b/test/pytests/test_schema_prefetcher.py @@ -0,0 +1,376 @@ +# type: ignore + +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock + +from mycli import schema_prefetcher as schema_prefetcher_module +from mycli.schema_prefetcher import SchemaPrefetcher, parse_prefetch_config +from mycli.sqlcompleter import SQLCompleter + + +def test_parse_prefetch_config_never() -> None: + assert parse_prefetch_config('never', []) == [] + assert parse_prefetch_config('NEVER', ['ignored', 'values']) == [] + assert parse_prefetch_config(' never ', []) == [] + + +def test_parse_prefetch_config_always() -> None: + assert parse_prefetch_config('always', []) is None + assert parse_prefetch_config('ALWAYS', []) is None + assert parse_prefetch_config(' always ', ['ignored']) is None + + +def test_parse_prefetch_config_listed() -> None: + assert parse_prefetch_config('listed', ['foo', 'bar', 'baz']) == ['foo', 'bar', 'baz'] + assert parse_prefetch_config('LISTED', ['solo']) == ['solo'] + assert parse_prefetch_config('listed', []) == [] + + +def test_parse_prefetch_config_unknown_mode_falls_back_to_always() -> None: + assert parse_prefetch_config('unknown', ['ignored']) is None + + +def make_mycli( + prefetch_mode: str = 'listed', + prefetch_list: list[str] | None = None, + dbname: str = 'current', + databases=None, +): + if prefetch_list is None: + prefetch_list = [] + if databases is None: + databases = ['current', 'other1', 'other2'] + completer = SQLCompleter(smart_completion=True) + completer.set_dbname(dbname) + sqlexecute = SimpleNamespace( + dbname=dbname, + user='u', + password='p', + host='h', + port=3306, + socket=None, + character_set='utf8mb4', + local_infile=False, + ssl=None, + ssh_user=None, + ssh_host=None, + ssh_port=22, + ssh_password=None, + ssh_key_filename=None, + databases=MagicMock(return_value=list(databases)), + ) + return SimpleNamespace( + completer=completer, + sqlexecute=sqlexecute, + prefetch_schemas_mode=prefetch_mode, + prefetch_schemas_list=prefetch_list, + _completer_lock=threading.Lock(), + prompt_session=None, + ) + + +def _fake_executor_factory(per_schema_tables, databases=None): + """Build an executor stub whose schema-aware methods yield prebuilt rows.""" + + def make(*_args, **_kwargs): + executor = MagicMock() + executor.databases.return_value = list(databases) if databases is not None else [] + executor.table_columns.side_effect = lambda schema=None: iter(per_schema_tables.get(schema, [])) + executor.foreign_keys.side_effect = lambda schema=None: iter([]) + executor.enum_values.side_effect = lambda schema=None: iter([]) + executor.functions.side_effect = lambda schema=None: iter([]) + executor.procedures.side_effect = lambda schema=None: iter([]) + executor.close = MagicMock() + return executor + + return make + + +def test_start_configured_skips_current_and_prefetches_others(monkeypatch): + mycli = make_mycli(prefetch_mode='listed', prefetch_list=['other1', 'current', 'other2']) + tables = { + 'other1': [('users', 'id'), ('users', 'email')], + 'other2': [('orders', 'id')], + } + monkeypatch.setattr(schema_prefetcher_module, 'SQLExecute', _fake_executor_factory(tables)) + + prefetcher = SchemaPrefetcher(mycli) + prefetcher.start_configured() + assert prefetcher._thread is not None + prefetcher._thread.join(timeout=5) + + tables_meta = mycli.completer.dbmetadata['tables'] + assert 'other1' in tables_meta + assert 'other2' in tables_meta + # Current schema must be untouched by the prefetcher. + assert 'current' not in tables_meta + assert set(tables_meta['other1'].keys()) == {'users'} + # Column list starts with '*' marker and contains escaped column names. + assert tables_meta['other1']['users'][0] == '*' + assert 'id' in tables_meta['other1']['users'] + + +def test_start_configured_all_resolves_from_databases(monkeypatch): + mycli = make_mycli(prefetch_mode='always', databases=['current', 'alpha', 'beta']) + tables = { + 'alpha': [('t_a', 'c')], + 'beta': [('t_b', 'c')], + } + monkeypatch.setattr( + schema_prefetcher_module, + 'SQLExecute', + _fake_executor_factory(tables, databases=['current', 'alpha', 'beta']), + ) + + prefetcher = SchemaPrefetcher(mycli) + prefetcher.start_configured() + assert prefetcher._thread is not None + prefetcher._thread.join(timeout=5) + + tables_meta = mycli.completer.dbmetadata['tables'] + assert 'alpha' in tables_meta + assert 'beta' in tables_meta + assert 'current' not in tables_meta + + +def test_start_configured_noop_when_disabled(monkeypatch): + mycli = make_mycli(prefetch_mode='never') + make_executor = MagicMock() + monkeypatch.setattr(schema_prefetcher_module, 'SQLExecute', make_executor) + + prefetcher = SchemaPrefetcher(mycli) + prefetcher.start_configured() + + assert prefetcher._thread is None + make_executor.assert_not_called() + + +def test_prefetch_schema_now_loads_single_schema(monkeypatch): + mycli = make_mycli(prefetch_mode='never') + tables = {'target': [('t1', 'c1')]} + monkeypatch.setattr(schema_prefetcher_module, 'SQLExecute', _fake_executor_factory(tables)) + + prefetcher = SchemaPrefetcher(mycli) + prefetcher.prefetch_schema_now('target') + assert prefetcher._thread is not None + prefetcher._thread.join(timeout=5) + + assert 'target' in mycli.completer.dbmetadata['tables'] + + +def test_stop_interrupts_running_prefetch(monkeypatch): + mycli = make_mycli(prefetch_mode='listed', prefetch_list=['a', 'b']) + monkeypatch.setattr( + schema_prefetcher_module, + 'SQLExecute', + _fake_executor_factory({'a': [], 'b': []}), + ) + + prefetcher = SchemaPrefetcher(mycli) + # Immediately cancel before any work runs. + prefetcher._cancel.set() + prefetcher._start(['a', 'b']) + if prefetcher._thread is not None: + prefetcher._thread.join(timeout=5) + # stop() must be idempotent and leave the prefetcher ready to run again. + prefetcher.stop() + assert prefetcher._thread is None + + +def test_start_skips_schemas_already_in_completer(monkeypatch): + """Previously-loaded schemas must not be re-fetched on refresh.""" + mycli = make_mycli(prefetch_mode='listed', prefetch_list=['keep', 'fresh']) + # Simulate a schema that was already loaded (e.g., preserved via + # copy_other_schemas_from after a completion refresh). + mycli.completer.dbmetadata['tables']['keep'] = {'cached_table': ['*', 'c1']} + + executor_calls: list[str] = [] + + def make(*_args, **_kwargs): + executor = MagicMock() + + def _track(schema=None): + executor_calls.append(schema) + return iter([]) + + executor.table_columns.side_effect = _track + executor.foreign_keys.side_effect = lambda schema=None: iter([]) + executor.enum_values.side_effect = lambda schema=None: iter([]) + executor.functions.side_effect = lambda schema=None: iter([]) + executor.procedures.side_effect = lambda schema=None: iter([]) + executor.close = MagicMock() + return executor + + monkeypatch.setattr(schema_prefetcher_module, 'SQLExecute', make) + + prefetcher = SchemaPrefetcher(mycli) + prefetcher.start_configured() + if prefetcher._thread is not None: + prefetcher._thread.join(timeout=5) + + # Only 'fresh' is queried; 'keep' and 'current' are skipped. + assert executor_calls == ['fresh'] + # Cached data for 'keep' is untouched. + assert mycli.completer.dbmetadata['tables']['keep'] == {'cached_table': ['*', 'c1']} + + +def test_is_prefetching_and_clear_loaded() -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + + assert prefetcher.is_prefetching() is False + + prefetcher._loaded.update({'alpha', 'beta'}) + prefetcher.clear_loaded() + + class FakeThread: + def is_alive(self) -> bool: + return True + + prefetcher._thread = FakeThread() + assert prefetcher.is_prefetching() is True + assert prefetcher._loaded == set() + + +def test_stop_joins_alive_thread_and_resets_state() -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + old_cancel = prefetcher._cancel + + class FakeThread: + def __init__(self) -> None: + self.join_timeout: float | None = None + + def is_alive(self) -> bool: + return True + + def join(self, timeout: float) -> None: + self.join_timeout = timeout + + fake_thread = FakeThread() + prefetcher._thread = fake_thread + + prefetcher.stop(timeout=1.5) + + assert old_cancel.is_set() + assert fake_thread.join_timeout == 1.5 + assert prefetcher._thread is None + assert prefetcher._cancel is not old_cancel + + +def test_prefetch_schema_now_ignores_empty_schema(monkeypatch) -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + stop = MagicMock() + start = MagicMock() + monkeypatch.setattr(prefetcher, 'stop', stop) + monkeypatch.setattr(prefetcher, '_start', start) + + prefetcher.prefetch_schema_now('') + + stop.assert_not_called() + start.assert_not_called() + + +def test_run_returns_when_database_listing_fails(monkeypatch) -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + executor = MagicMock() + executor.databases.side_effect = RuntimeError('boom') + executor.close = MagicMock() + invalidate = MagicMock() + monkeypatch.setattr(prefetcher, '_make_executor', lambda: executor) + monkeypatch.setattr(prefetcher, '_invalidate_app', invalidate) + + prefetcher._run(None) + + executor.databases.assert_called_once_with() + executor.close.assert_called_once_with() + invalidate.assert_called_once_with() + + +def test_run_returns_when_cancelled_before_prefetch(monkeypatch) -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + executor = MagicMock() + executor.close = MagicMock() + prefetch = MagicMock() + invalidate = MagicMock() + prefetcher._cancel.set() + monkeypatch.setattr(prefetcher, '_make_executor', lambda: executor) + monkeypatch.setattr(prefetcher, '_prefetch_one', prefetch) + monkeypatch.setattr(prefetcher, '_invalidate_app', invalidate) + + prefetcher._run(['schema1']) + + prefetch.assert_not_called() + assert prefetcher._loaded == set() + executor.close.assert_called_once_with() + invalidate.assert_called_once_with() + + +def test_run_logs_prefetch_error_and_continues(monkeypatch) -> None: + mycli = make_mycli() + prefetcher = SchemaPrefetcher(mycli) + executor = MagicMock() + executor.close = MagicMock() + invalidate = MagicMock() + calls: list[str] = [] + + def fake_prefetch(_executor, schema: str) -> None: + calls.append(schema) + if schema == 'bad': + raise RuntimeError('boom') + + monkeypatch.setattr(prefetcher, '_make_executor', lambda: executor) + monkeypatch.setattr(prefetcher, '_prefetch_one', fake_prefetch) + monkeypatch.setattr(prefetcher, '_invalidate_app', invalidate) + + prefetcher._run(['bad', 'good']) + + assert calls == ['bad', 'good'] + assert prefetcher._loaded == {'good'} + executor.close.assert_called_once_with() + invalidate.assert_called_once_with() + + +def test_prefetch_one_loads_foreign_keys_enums_functions_and_procedures(monkeypatch) -> None: + mycli = make_mycli() + load_schema_metadata = MagicMock() + mycli.completer.load_schema_metadata = load_schema_metadata + prefetcher = SchemaPrefetcher(mycli) + invalidate = MagicMock() + monkeypatch.setattr(prefetcher, '_invalidate_app', invalidate) + + executor = MagicMock() + executor.table_columns.return_value = iter([('orders', 'id')]) + executor.foreign_keys.return_value = iter([('orders', 'user_id', 'users', 'id')]) + executor.enum_values.return_value = iter([('orders', 'status', ['pending', 'shipped'])]) + executor.functions.return_value = iter([(), ('calc_tax',), (None,)]) + executor.procedures.return_value = iter([None, ('rebuild_cache',), ('',)]) + + prefetcher._prefetch_one(executor, 'analytics') + + load_schema_metadata.assert_called_once_with( + schema='analytics', + table_columns={'orders': ['*', 'id']}, + foreign_keys={ + 'tables': {'orders': {'users'}, 'users': {'orders'}}, + 'relations': [('orders', 'user_id', 'users', 'id')], + }, + enum_values={'orders': {'status': ['pending', 'shipped']}}, + functions={'calc_tax': None}, + procedures={'rebuild_cache': None}, + ) + invalidate.assert_called_once_with() + + +def test_invalidate_app_calls_prompt_session_app() -> None: + mycli = make_mycli() + mycli.prompt_session = SimpleNamespace(app=SimpleNamespace(invalidate=MagicMock())) + prefetcher = SchemaPrefetcher(mycli) + + prefetcher._invalidate_app() + + mycli.prompt_session.app.invalidate.assert_called_once_with() diff --git a/test/pytests/test_smart_completion_public_schema_only.py b/test/pytests/test_smart_completion_public_schema_only.py new file mode 100644 index 00000000..44a96741 --- /dev/null +++ b/test/pytests/test_smart_completion_public_schema_only.py @@ -0,0 +1,1202 @@ +# type: ignore + +import os.path +from unittest.mock import patch + +from prompt_toolkit.completion import Completion +from prompt_toolkit.document import Document +import pytest + +import mycli.packages.special.main as special +from test.utils import pygments_below + +metadata = { + "users": ["id", "email", "first_name", "last_name"], + "orders": ["id", "ordered_date", "status"], + "select": ["id", "insert", "ABC"], + "réveillé": ["id", "insert", "ABC"], + "time_zone": ["Time_zone_id"], + "time_zone_leap_second": ["Time_zone_id"], + "time_zone_name": ["Time_zone_id"], + "time_zone_transition": ["Time_zone_id"], + "time_zone_transition_type": ["Time_zone_id"], +} + + +@pytest.fixture +def completer(): + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables, columns = [], [] + + for table, cols in metadata.items(): + tables.append((table,)) + columns.extend([(table, col) for col in cols]) + + databases = ["test", "test 2"] + + for db in databases: + comp.extend_schemata(db) + comp.extend_database_names(databases) + comp.set_dbname("test") + comp.extend_relations(tables, kind="tables") + comp.extend_columns(columns, kind="tables") + comp.extend_enum_values([("orders", "status", ["pending", "shipped"])]) + comp.extend_special_commands(special.COMMANDS) + + return comp + + +@pytest.fixture +def empty_completer(): + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables, columns = [], [] + + for table, cols in metadata.items(): + tables.append((table,)) + columns.extend([(table, col) for col in cols]) + + db = 'empty' + + comp.extend_schemata(db) + comp.extend_database_names([db]) + comp.set_dbname(db) + comp.extend_special_commands(special.COMMANDS) + + return comp + + +@pytest.fixture +def complete_event(): + from unittest.mock import Mock + + return Mock() + + +def test_use_database_completion(completer, complete_event): + text = "USE " + position = len(text) + special.register_special_command( + ..., + 'use', + '\\u [database]', + 'Change to a new database.', + aliases=[special.SpecialCommandAlias('\\u', case_sensitive=False)], + ) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_special_name_completion(completer, complete_event): + text = "\\d" + position = len("\\d") + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [Completion(text="\\dt", start_position=-2)] + + +def test_empty_string_completion(completer, complete_event): + text = "" + position = 0 + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert list(map(Completion, completer.special_commands + completer.keywords)) == result + + +def test_select_keyword_completion(completer, complete_event): + text = "SEL" + position = len("SEL") + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text='SELECT', start_position=-3), + Completion(text='SERIAL', start_position=-3), + Completion(text='MASTER_LOG_FILE', start_position=-3), + Completion(text='MASTER_LOG_POS', start_position=-3), + Completion(text='MASTER_TLS_CIPHERSUITES', start_position=-3), + Completion(text='MASTER_TLS_VERSION', start_position=-3), + Completion(text='SCHEDULE', start_position=-3), + Completion(text='SERIALIZABLE', start_position=-3), + ] + + +def test_select_star(completer, complete_event): + text = "SELECT * " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == list(map(Completion, completer.keywords)) + + +def test_introducer_completion(completer, complete_event): + completer.extend_character_sets([('latin1',), ('utf8mb4',)]) + text = 'SELECT _' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert '_latin1' in result_text + assert '_utf8mb4' in result_text + + +def test_collation_completion(completer, complete_event): + completer.extend_collations([('utf16le_bin',), ('utf8mb4_unicode_ci',)]) + text = 'SELECT "text" COLLATE ' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'utf16le_bin' in result_text + assert 'utf8mb4_unicode_ci' in result_text + + +def test_transcoding_completion_1(completer, complete_event): + completer.extend_character_sets([('latin1',), ('utf8mb4',)]) + text = 'SELECT CONVERT("text" USING ' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'latin1' in result_text + assert 'utf8mb4' in result_text + + +def test_transcoding_completion_2(completer, complete_event): + completer.extend_character_sets([('utf8mb3',), ('utf8mb4',)]) + text = 'SELECT CONVERT("text" USING u' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'utf8mb3' in result_text + assert 'utf8mb4' in result_text + + +def test_transcoding_completion_3(completer, complete_event): + completer.extend_character_sets([('latin1',), ('utf8mb4',)]) + text = 'SELECT CAST("text" AS CHAR CHARACTER SET ' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'latin1' in result_text + assert 'utf8mb4' in result_text + + +def test_transcoding_completion_4(completer, complete_event): + completer.extend_character_sets([('utf8mb3',), ('utf8mb4',)]) + text = 'SELECT CAST("text" AS CHAR CHARACTER SET u' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'utf8mb3' in result_text + assert 'utf8mb4' in result_text + + +def test_where_transcoding_completion_1(completer, complete_event): + completer.extend_character_sets([('latin1',), ('utf8mb4',)]) + text = 'SELECT * FROM users WHERE CONVERT(email USING ' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'latin1' in result_text + assert 'utf8mb4' in result_text + + +def test_where_transcoding_completion_2(completer, complete_event): + completer.extend_character_sets([('latin1',), ('utf8mb4',)]) + text = 'SELECT * FROM users WHERE CAST(email AS CHAR CHARACTER SET ' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + result_text = [item.text for item in result] + assert 'latin1' in result_text + assert 'utf8mb4' in result_text + + +def test_table_completion(completer, complete_event): + text = "SELECT * FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_select_filtered_table_completion(completer, complete_event): + text = "SELECT ABC FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_sub_select_filtered_table_completion(completer, complete_event): + text = "SELECT * FROM (SELECT ordered_date FROM " + position = len(text) + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="orders", start_position=0), + Completion(text="users", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_enum_value_completion(completer, complete_event): + text = "SELECT * FROM orders WHERE status = " + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="'pending'", start_position=0), + Completion(text="'shipped'", start_position=0), + ] + + +def test_function_name_completion(completer, complete_event): + text = "SELECT MA" + position = len("SELECT MA") + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text='MAX', start_position=-2), + Completion(text='MATCH', start_position=-2), + Completion(text='MAKEDATE', start_position=-2), + Completion(text='MAKETIME', start_position=-2), + Completion(text='MAKE_SET', start_position=-2), + Completion(text='MASTER_POS_WAIT', start_position=-2), + Completion(text='email', start_position=-2), + ] + + +def test_suggested_column_names(completer, complete_event): + """Suggest column and function names when selecting from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT from users" + position = len("SELECT ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="users", start_position=0)] + ) + + +def test_suggested_column_names_empty_db(empty_completer, complete_event): + """Suggest * and function when selecting from no-table db. + + :param empty_completer: + :param complete_event: + :return: + + """ + text = "SELECT " + position = len("SELECT ") + result = list(empty_completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list( + [ + Completion(text="*", start_position=0), + ] + + list(map(Completion, empty_completer.functions)) + ) + + +def test_suggested_column_names_in_function(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT MAX( from users" + position = len("SELECT MAX(") + result = completer.get_completions(Document(text=text, cursor_position=position), complete_event) + assert list(result) == [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + +def test_suggested_column_names_with_table_dot(completer, complete_event): + """Suggest column names on table name and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT users. from users" + position = len("SELECT users.") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + +def test_suggested_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT u. from users u" + position = len("SELECT u.") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + +def test_suggested_multiple_column_names(completer, complete_event): + """Suggest column and function names when selecting multiple columns from + table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT id, from users u" + position = len("SELECT id, ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == list( + [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + list(map(Completion, completer.functions)) + + [Completion(text="u", start_position=0)] + ) + + +def test_suggested_multiple_column_names_with_alias(completer, complete_event): + """Suggest column names on table alias and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT u.id, u. from users u" + position = len("SELECT u.id, u.") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + +def test_suggested_multiple_column_names_with_dot(completer, complete_event): + """Suggest column names on table names and dot when selecting multiple + columns from table. + + :param completer: + :param complete_event: + :return: + + """ + text = "SELECT users.id, users. from users u" + position = len("SELECT users.id, users.") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="*", start_position=0), + Completion(text="id", start_position=0), + Completion(text="email", start_position=0), + Completion(text="first_name", start_position=0), + Completion(text="last_name", start_position=0), + ] + + +def test_suggested_aliases_after_on(completer, complete_event): + text = "SELECT u.name, o.id FROM users u JOIN orders o ON " + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="u", start_position=0), + Completion(text="o", start_position=0), + ] + + +def test_suggested_aliases_after_on_right_side(completer, complete_event): + text = "SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = " + position = len("SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="u", start_position=0), + Completion(text="o", start_position=0), + ] + + +def test_suggested_tables_after_on(completer, complete_event): + text = "SELECT users.name, orders.id FROM users JOIN orders ON " + position = len("SELECT users.name, orders.id FROM users JOIN orders ON ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + ] + + +def test_suggested_tables_after_on_right_side(completer, complete_event): + text = "SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = " + position = len("SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + ] + + +def test_table_names_after_from(completer, complete_event): + text = "SELECT * FROM " + position = len("SELECT * FROM ") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="users", start_position=0), + Completion(text="orders", start_position=0), + Completion(text="`select`", start_position=0), + Completion(text="`réveillé`", start_position=0), + Completion(text="time_zone", start_position=0), + Completion(text="time_zone_leap_second", start_position=0), + Completion(text="time_zone_name", start_position=0), + Completion(text="time_zone_transition", start_position=0), + Completion(text="time_zone_transition_type", start_position=0), + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + ] + + +def test_table_names_leading_partial(completer, complete_event): + text = "SELECT * FROM time_zone" + position = len("SELECT * FROM time_zone") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="time_zone", start_position=-9), + Completion(text="time_zone_name", start_position=-9), + Completion(text="time_zone_transition", start_position=-9), + Completion(text="time_zone_leap_second", start_position=-9), + Completion(text="time_zone_transition_type", start_position=-9), + ] + + +def test_table_names_inter_partial(completer, complete_event): + text = "SELECT * FROM time_leap" + position = len("SELECT * FROM time_leap") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="time_zone_leap_second", start_position=-9), + Completion(text='time_zone_name', start_position=-9), + Completion(text='time_zone_transition', start_position=-9), + Completion(text='time_zone_transition_type', start_position=-9), + ] + + +def test_table_names_fuzzy(completer, complete_event): + text = "SELECT * FROM tim_leap" + position = len("SELECT * FROM tim_leap") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="time_zone_leap_second", start_position=-8), + ] + + +def test_auto_escaped_col_names(completer, complete_event): + text = "SELECT from `select`" + position = len("SELECT ") + result = [x.text for x in completer.get_completions(Document(text=text, cursor_position=position), complete_event)] + expected = ( + [ + "*", + "id", + "`insert`", + "ABC", + ] + + completer.functions + + ["select"] + ) + assert result == expected + + +def test_un_escaped_table_names(completer, complete_event): + text = "SELECT from réveillé" + position = len("SELECT ") + result = [x.text for x in completer.get_completions(Document(text=text, cursor_position=position), complete_event)] + assert result == [ + "*", + "id", + "`insert`", + "ABC", + ] + completer.functions + ["réveillé"] + + +# todo: the fixtures are insufficient; the database name should also appear in the result +def test_grant_on_suggets_tables_and_schemata(completer, complete_event): + text = "GRANT ALL ON " + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="test", start_position=0), + Completion(text="`test 2`", start_position=0), + Completion(text='users', start_position=0), + Completion(text='orders', start_position=0), + Completion(text='`select`', start_position=0), + Completion(text='`réveillé`', start_position=0), + Completion(text='time_zone', start_position=0), + Completion(text='time_zone_leap_second', start_position=0), + Completion(text='time_zone_name', start_position=0), + Completion(text='time_zone_transition', start_position=0), + Completion(text='time_zone_transition_type', start_position=0), + ] + + +# todo: this test belongs more logically in test_naive_completion.py, but it didn't work there: +# multiple completion candidates were not suggested. +def test_deleted_keyword_completion(completer, complete_event): + text = "exi" + position = len("exi") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text="exit", start_position=-3), + Completion(text='exists', start_position=-3), + Completion(text='explain', start_position=-3), + Completion(text='expire', start_position=-3), + ] + + +def test_numbers_no_completion(completer, complete_event): + text = "SELECT COUNT(1) FROM time_zone WHERE Time_zone_id = 1" + position = len("SELECT COUNT(1) FROM time_zone WHERE Time_zone_id = 1") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] # ie not INT1 + + +def dummy_list_path(dir_name): + dirs = { + "/": [ + "dir1", + "file1.sql", + "file2.sql", + ], + "/dir1": [ + "subdir1", + "subfile1.sql", + "subfile2.sql", + ], + "/dir1/subdir1": [ + "lastfile.sql", + ], + } + return dirs.get(dir_name, []) + + +@patch("mycli.packages.filepaths.list_path", new=dummy_list_path) +@pytest.mark.parametrize( + "text,expected", + [ + ('source ', [('/', 0), ('~', 0), ('.', 0), ('..', 0)]), + ("source /", [("dir1", 0), ("file1.sql", 0), ("file2.sql", 0)]), + ("source /dir1/", [("subdir1", 0), ("subfile1.sql", 0), ("subfile2.sql", 0)]), + ("source /dir1/subdir1/", [("lastfile.sql", 0)]), + ], +) +def test_file_name_completion(completer, complete_event, text, expected): + position = len(text) + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + expected = [Completion(txt, pos) for txt, pos in expected] + assert result == expected + + +def test_auto_case_heuristic(completer, complete_event): + text = "select json_v" + position = len("select json_v") + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert [x.text for x in result] == [ + 'json_value', + 'json_valid', + ] + + +def test_create_table_like_completion(completer, complete_event): + text = "CREATE TABLE foo LIKE ti" + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert [x.text for x in result] == [ + 'time_zone', + 'time_zone_name', + 'time_zone_transition', + 'time_zone_leap_second', + 'time_zone_transition_type', + ] + + +def test_source_eager_completion(completer, complete_event, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + os.mkdir('doc') + text = "source do" + position = len(text) + script_filename = 'do_these_statements.sql' + f = open(script_filename, 'w') + f.close() + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + success = True + error = 'unknown' + try: + assert [x.text for x in result] == [ + script_filename, + 'doc/', + ] + except AssertionError as e: + success = False + error = e + if os.path.exists(script_filename): + os.remove(script_filename) + if not success: + raise AssertionError(error) + + +def test_source_leading_dot_suggestions_completion(completer, complete_event, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + os.mkdir('doc') + text = "source ./do" + position = len(text) + script_filename = 'do_these_statements.sql' + f = open(script_filename, 'w') + f.close() + special.register_special_command( + ..., + 'source', + '\\. ', + 'Execute commands from file.', + aliases=[special.SpecialCommandAlias('\\.', case_sensitive=False)], + ) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + success = True + error = 'unknown' + try: + assert [x.text for x in result] == [ + script_filename, + 'doc/', + ] + except AssertionError as e: + success = False + error = e + if os.path.exists(script_filename): + os.remove(script_filename) + if not success: + raise AssertionError(error) + + +def test_string_no_completion(completer, complete_event): + text = 'select "json' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +def test_string_no_completion_single_quote(completer, complete_event): + text = "select 'json" + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +def test_string_no_completion_spaces(completer, complete_event): + text = 'select "nocomplete json' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +def test_string_no_completion_spaces_inner_1(completer, complete_event): + text = 'select "json nocomplete' + position = len('select "json') + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +def test_string_no_completion_spaces_inner_2(completer, complete_event): + text = 'select "json nocomplete' + position = len('select "json ') + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +def test_backticked_column_completion(completer, complete_event): + text = 'select `Tim' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + # todo it would be nicer if the column names sorted to the top + Completion(text='`time`', start_position=-4), + Completion(text='`timediff`', start_position=-4), + Completion(text='`timestamp`', start_position=-4), + Completion(text='`time_format`', start_position=-4), + Completion(text='`time_to_sec`', start_position=-4), + Completion(text='`Time_zone_id`', start_position=-4), + Completion(text='`timestampadd`', start_position=-4), + Completion(text='`timestampdiff`', start_position=-4), + Completion(text='`datetime`', start_position=-4), + Completion(text='`optimize`', start_position=-4), + Completion(text='`optimizer_costs`', start_position=-4), + Completion(text='`utc_time`', start_position=-4), + Completion(text='`utc_timestamp`', start_position=-4), + Completion(text='`current_time`', start_position=-4), + Completion(text='`current_timestamp`', start_position=-4), + Completion(text='`localtime`', start_position=-4), + Completion(text='`localtimestamp`', start_position=-4), + Completion(text='`password_lock_time`', start_position=-4), + ] + + +def test_backticked_column_completion_component(completer, complete_event): + text = 'select `com' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + # todo it would be nicer if "comment" sorted to the top because it is a column name, + # and because it is a reserved word + Completion(text='`commit`', start_position=-4), + Completion(text='`comment`', start_position=-4), + Completion(text='`compact`', start_position=-4), + Completion(text='`compress`', start_position=-4), + Completion(text='`committed`', start_position=-4), + Completion(text='`component`', start_position=-4), + Completion(text='`completion`', start_position=-4), + Completion(text='`compressed`', start_position=-4), + Completion(text='`compression`', start_position=-4), + Completion(text='`column`', start_position=-4), + Completion(text='`column_format`', start_position=-4), + Completion(text='`column_name`', start_position=-4), + Completion(text='`columns`', start_position=-4), + Completion(text='`second_microsecond`', start_position=-4), + Completion(text='`uncommitted`', start_position=-4), + ] + + +def test_backticked_column_completion_two_character(completer, complete_event): + text = 'select `f' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + expected = [ + # todo it would be nicer if the column name "first_name" sorted to the top + Completion(text='`for`', start_position=-2), + Completion(text='`from`', start_position=-2), + Completion(text='`fast`', start_position=-2), + Completion(text='`file`', start_position=-2), + Completion(text='`full`', start_position=-2), + Completion(text='`floor`', start_position=-2), + Completion(text='`false`', start_position=-2), + Completion(text='`field`', start_position=-2), + Completion(text='`fixed`', start_position=-2), + Completion(text='`float`', start_position=-2), + Completion(text='`fetch`', start_position=-2), + Completion(text='`files`', start_position=-2), + Completion(text='`first`', start_position=-2), + Completion(text='`flush`', start_position=-2), + Completion(text='`force`', start_position=-2), + Completion(text='`found`', start_position=-2), + Completion(text='`format`', start_position=-2), + Completion(text='`float4`', start_position=-2), + Completion(text='`float8`', start_position=-2), + Completion(text='`factor`', start_position=-2), + Completion(text='`faults`', start_position=-2), + Completion(text='`fields`', start_position=-2), + Completion(text='`filter`', start_position=-2), + Completion(text='`finish`', start_position=-2), + Completion(text='`follows`', start_position=-2), + Completion(text='`foreign`', start_position=-2), + Completion(text='`fulltext`', start_position=-2), + Completion(text='`function`', start_position=-2), + Completion(text='`from_days`', start_position=-2), + Completion(text='`file_name`', start_position=-2), + Completion(text='`following`', start_position=-2), + Completion(text='`first_name`', start_position=-2), + Completion(text='`found_rows`', start_position=-2), + Completion(text='`find_in_set`', start_position=-2), + Completion(text='`first_value`', start_position=-2), + Completion(text='`from_base64`', start_position=-2), + Completion(text='`from_vector`', start_position=-2), + Completion(text='`file_format`', start_position=-2), + Completion(text='`file_prefix`', start_position=-2), + Completion(text='`foreign key`', start_position=-2), + Completion(text='`format_bytes`', start_position=-2), + Completion(text='`file_pattern`', start_position=-2), + Completion(text='`from_unixtime`', start_position=-2), + Completion(text='`file_block_size`', start_position=-2), + Completion(text='`format_pico_time`', start_position=-2), + Completion(text='`failed_login_attempts`', start_position=-2), + Completion(text='`left join`', start_position=-2), + Completion(text='`after`', start_position=-2), + Completion(text='`before`', start_position=-2), + Completion(text='`default`', start_position=-2), + Completion(text='`default_auth`', start_position=-2), + Completion(text='`definer`', start_position=-2), + Completion(text='`definition`', start_position=-2), + Completion(text='`enforced`', start_position=-2), + Completion(text='`if`', start_position=-2), + Completion(text='`infile`', start_position=-2), + Completion(text='`left`', start_position=-2), + Completion(text='`logfile`', start_position=-2), + Completion(text='`of`', start_position=-2), + Completion(text='`off`', start_position=-2), + Completion(text='`offset`', start_position=-2), + Completion(text='`outfile`', start_position=-2), + Completion(text='`profile`', start_position=-2), + Completion(text='`profiles`', start_position=-2), + Completion(text='`reference`', start_position=-2), + Completion(text='`references`', start_position=-2), + ] + + if pygments_below("2.20"): + for newer in [ + Completion(text='`file_format`', start_position=-2), + Completion(text='`file_name`', start_position=-2), + Completion(text='`file_pattern`', start_position=-2), + Completion(text='`file_prefix`', start_position=-2), + Completion(text='`files`', start_position=-2), + Completion(text='`from_vector`', start_position=-2), + ]: + expected.remove(newer) + + assert result == expected + + +def test_backticked_column_completion_three_character(completer, complete_event): + text = 'select `fi' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + expected = [ + # todo it would be nicer if the column name "first_name" sorted to the top + Completion(text='`file`', start_position=-3), + Completion(text='`field`', start_position=-3), + Completion(text='`fixed`', start_position=-3), + Completion(text='`files`', start_position=-3), + Completion(text='`first`', start_position=-3), + Completion(text='`fields`', start_position=-3), + Completion(text='`filter`', start_position=-3), + Completion(text='`finish`', start_position=-3), + Completion(text='`file_name`', start_position=-3), + Completion(text='`first_name`', start_position=-3), + Completion(text='`find_in_set`', start_position=-3), + Completion(text='`first_value`', start_position=-3), + Completion(text='`file_format`', start_position=-3), + Completion(text='`file_prefix`', start_position=-3), + Completion(text='`file_pattern`', start_position=-3), + Completion(text='`file_block_size`', start_position=-3), + Completion(text='`definer`', start_position=-3), + Completion(text='`definition`', start_position=-3), + Completion(text='`failed_login_attempts`', start_position=-3), + Completion(text='`foreign`', start_position=-3), + Completion(text='`infile`', start_position=-3), + Completion(text='`logfile`', start_position=-3), + Completion(text='`outfile`', start_position=-3), + Completion(text='`profile`', start_position=-3), + Completion(text='`profiles`', start_position=-3), + Completion(text='`foreign key`', start_position=-3), + ] + + if pygments_below("2.20"): + for newer in [ + Completion(text='`files`', start_position=-3), + Completion(text='`file_name`', start_position=-3), + Completion(text='`file_format`', start_position=-3), + Completion(text='`file_pattern`', start_position=-3), + Completion(text='`file_prefix`', start_position=-3), + ]: + expected.remove(newer) + + assert result == expected + + +def test_backticked_column_completion_four_character(completer, complete_event): + text = 'select `fir' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + # todo it would be nicer if the column name "first_name" sorted to the top + Completion(text='`first`', start_position=-4), + Completion(text='`first_name`', start_position=-4), + Completion(text='`first_value`', start_position=-4), + Completion(text='`definer`', start_position=-4), + Completion(text='`filter`', start_position=-4), + ] + + +def test_backticked_table_completion_required(completer, complete_event): + text = 'select ABC from `rév' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text='`réveillé`', start_position=-4), + ] + + +def test_backticked_table_completion_not_required(completer, complete_event): + text = 'select * from `t' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [ + Completion(text='`test`', start_position=-2), + Completion(text='`test 2`', start_position=-2), + Completion(text='`time_zone`', start_position=-2), + Completion(text='`time_zone_name`', start_position=-2), + Completion(text='`time_zone_transition`', start_position=-2), + Completion(text='`time_zone_leap_second`', start_position=-2), + Completion(text='`time_zone_transition_type`', start_position=-2), + ] + + +def test_string_no_completion_backtick(completer, complete_event): + text = 'select * from "`t' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +# todo this shouldn't suggest anything but the space resets the logic +# and it completes on "bar" alone +@pytest.mark.xfail +def test_backticked_no_completion_spaces(completer, complete_event): + text = 'select * from `nocomplete bar' + position = len(text) + result = list(completer.get_completions(Document(text=text, cursor_position=position), complete_event)) + assert result == [] + + +# Foreign key completion tests +@pytest.fixture +def fk_completer(): + """SQLCompleter with tables and a FK relationship. + + Schema: + orders (id, user_id, ordered_date, status) FK: user_id -> users.id + users (id, email, first_name) + tags (id, name) no FK + """ + import mycli.packages.special.main as special + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + + tables = [("orders",), ("users",), ("tags",)] + columns = [ + ("orders", "id"), + ("orders", "user_id"), + ("orders", "ordered_date"), + ("orders", "status"), + ("users", "id"), + ("users", "email"), + ("users", "first_name"), + ("tags", "id"), + ("tags", "name"), + ] + fk_data = [("orders", "user_id", "users", "id")] + + comp.extend_schemata("test") + comp.extend_database_names(["test"]) + comp.set_dbname("test") + comp.extend_relations(tables, kind="tables") + comp.extend_columns(columns, kind="tables") + comp.extend_foreign_keys(fk_data) + comp.extend_special_commands(special.COMMANDS) + + return comp + + +def test_extend_foreign_keys_stores_relation(fk_completer): + relations = fk_completer.dbmetadata["foreign_keys"]["test"]["relations"] + assert ("orders", "user_id", "users", "id") in relations + + +def test_extend_foreign_keys_stores_bidirectional_table_map(fk_completer): + tables_map = fk_completer.dbmetadata["foreign_keys"]["test"]["tables"] + assert "users" in tables_map["orders"] + assert "orders" in tables_map["users"] + + +def test_extend_foreign_keys_unrelated_table_absent_from_map(fk_completer): + tables_map = fk_completer.dbmetadata["foreign_keys"]["test"]["tables"] + assert "tags" not in tables_map + + +def test_fk_join_conditions_with_aliases(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o"), (None, "users", "u")]) + assert conditions == ["o.user_id = u.id"] + + +def test_fk_join_conditions_without_aliases(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", None), (None, "users", None)]) + assert conditions == ["orders.user_id = users.id"] + + +def test_fk_join_conditions_single_table_yields_nothing(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o")]) + assert conditions == [] + + +def test_fk_join_conditions_unrelated_tables_yields_nothing(fk_completer): + conditions = fk_completer._fk_join_conditions([(None, "orders", "o"), (None, "tags", "t")]) + assert conditions == [] + + +def test_join_suggests_fk_table_before_unrelated(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "users" in result + assert "tags" in result + assert result.index("users") < result.index("tags") + + +def test_join_fk_lookup_is_bidirectional(fk_completer, complete_event): + text = "SELECT * FROM users JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders" in result + assert "tags" in result + assert result.index("orders") < result.index("tags") + + +def test_join_unrelated_table_still_suggests_all_tables(fk_completer, complete_event): + text = "SELECT * FROM tags JOIN " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders" in result + assert "users" in result + + +def test_on_suggests_fk_condition_with_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN users u ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "o.user_id = u.id" in result + + +def test_on_suggests_fk_condition_without_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN users ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders.user_id = users.id" in result + + +def test_on_fk_condition_appears_before_aliases(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN users u ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert result.index("o.user_id = u.id") < result.index("o") + + +def test_on_no_fk_condition_for_unrelated_join(fk_completer, complete_event): + text = "SELECT * FROM orders o JOIN tags t ON " + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert not any("=" in r for r in result) + assert "o" in result + assert "t" in result + + +def test_on_partial_text_filters_fk_condition(fk_completer, complete_event): + text = "SELECT * FROM orders JOIN users ON ord" + result = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + assert "orders.user_id = users.id" in result + + +def test_fk_reserved_column_names_are_escaped(): + """FK columns that are reserved words or need quoting must be backtick-escaped.""" + import mycli.sqlcompleter as sqlcompleter + + comp = sqlcompleter.SQLCompleter(smart_completion=True) + comp.extend_schemata("test") + comp.set_dbname("test") + comp.extend_foreign_keys([("orders", "order", "users", "select")]) + + relations = comp.dbmetadata["foreign_keys"]["test"]["relations"] + assert ("orders", "`order`", "users", "`select`") in relations + + conditions = comp._fk_join_conditions([(None, "orders", "o"), (None, "users", "u")]) + assert conditions == ["o.`order` = u.`select`"] + + +def test_fk_conditions_ignore_cross_schema_tables(fk_completer): + """Tables qualified with a foreign schema are excluded from FK condition generation.""" + tables = [("other_db", "orders", "o"), (None, "users", "u")] + conditions = fk_completer._fk_join_conditions(tables) + assert conditions == [] + + +def test_join_priority_ignores_cross_schema_table(fk_completer, complete_event): + """Schema-qualified tables in FROM do not trigger FK priority using current-db metadata.""" + text = "SELECT * FROM other_db.orders JOIN " + result_cross_schema = [c.text for c in fk_completer.get_completions(Document(text=text, cursor_position=len(text)), complete_event)] + # A table with no FK relationships at all should give the same ordering, + # confirming that no FK priority was applied for the cross-schema table. + text_no_fk = "SELECT * FROM tags JOIN " + result_no_fk = [ + c.text for c in fk_completer.get_completions(Document(text=text_no_fk, cursor_position=len(text_no_fk)), complete_event) + ] + assert result_cross_schema == result_no_fk diff --git a/test/pytests/test_special_dbcommands.py b/test/pytests/test_special_dbcommands.py new file mode 100644 index 00000000..2859e654 --- /dev/null +++ b/test/pytests/test_special_dbcommands.py @@ -0,0 +1,332 @@ +# type: ignore + +from unittest.mock import MagicMock + +from pymysql import ProgrammingError + +from mycli.packages.completion_engine import suggest_type +from mycli.packages.special import dbcommands +from mycli.packages.special.dbcommands import list_databases, list_tables, status +from test.pytests.test_completion_engine import sorted_dicts + + +class FakeConnection: + def __init__( + self, + *, + host: str = 'db.example', + port: int = 3306, + host_info: str = 'Localhost via UNIX socket', + unix_socket: str | None = None, + thread_id_value: int = 42, + ) -> None: + self.host = host + self.port = port + self.host_info = host_info + self.unix_socket = unix_socket + self._thread_id_value = thread_id_value + + def thread_id(self) -> int: + return self._thread_id_value + + +class FakeCursor: + def __init__( + self, + *, + query_results: dict[str, dict[str, object]], + connection: FakeConnection | None = None, + fail_on_queries: set[str] | None = None, + ) -> None: + self.query_results = query_results + self.connection = connection or FakeConnection() + self.fail_on_queries = fail_on_queries or set() + self.description = None + self.current_query = None + self.executed: list[str] = [] + + def execute(self, query: str) -> None: + self.executed.append(query) + self.current_query = query + if query in self.fail_on_queries: + raise ProgrammingError() + self.description = self.query_results.get(query, {}).get('description') + + def fetchall(self): + return self.query_results.get(self.current_query, {}).get('rows', []) + + def fetchone(self): + rows = self.query_results.get(self.current_query, {}).get('rows', []) + return rows[0] if rows else None + + +def test_list_tables_verbose_preserves_field_results(): + """Test that \\dt+ table_name returns SHOW FIELDS results, not SHOW CREATE TABLE results. + + This is a regression test for a bug where the cursor was reused for SHOW CREATE TABLE, + which overwrote the SHOW FIELDS results. + """ + # Mock cursor that simulates MySQL behavior + cur = MagicMock() + + # Track which query is being executed + query_results = { + 'SHOW FIELDS FROM test_table': { + 'description': [('Field',), ('Type',), ('Null',), ('Key',), ('Default',), ('Extra',)], + 'rows': [ + ('id', 'int', 'NO', 'PRI', None, 'auto_increment'), + ('name', 'varchar(255)', 'YES', '', None, ''), + ], + }, + 'SHOW CREATE TABLE test_table': { + 'description': [('Table',), ('Create Table',)], + 'rows': [('test_table', 'CREATE TABLE `test_table` ...')], + }, + } + + current_query = [None] # Use list to allow mutation in nested function + + def execute_side_effect(query): + current_query[0] = query + cur.description = query_results[query]['description'] + cur.rowcount = len(query_results[query]['rows']) + + def fetchall_side_effect(): + return query_results[current_query[0]]['rows'] + + def fetchone_side_effect(): + rows = query_results[current_query[0]]['rows'] + return rows[0] if rows else None + + cur.execute.side_effect = execute_side_effect + cur.fetchall.side_effect = fetchall_side_effect + cur.fetchone.side_effect = fetchone_side_effect + + # Call list_tables with command_verbosity=True (simulating \dt+ table_name) + results = list_tables(cur, arg='test_table', command_verbosity=True) + + assert len(results) == 1 + result = results[0] + + # The header should be from SHOW FIELDS + assert result.header == ['Field', 'Type', 'Null', 'Key', 'Default', 'Extra'] + + # The results should contain the field data, not be empty + # Convert to list if it's a cursor or iterable + result_data = list(result.rows) if hasattr(result.rows, '__iter__') else result.rows + assert len(result_data) == 2 + assert result_data[0][0] == 'id' + assert result_data[1][0] == 'name' + + # The postamble should contain the CREATE TABLE statement + assert 'CREATE TABLE' in result.postamble + + +def test_u_suggests_databases(): + suggestions = suggest_type("\\u ", "\\u ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "database"}]) + + +def test_describe_table(): + suggestions = suggest_type("\\dt", "\\dt ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + +def test_list_or_show_create_tables(): + suggestions = suggest_type("\\dt+", "\\dt+ ") + assert sorted_dicts(suggestions) == sorted_dicts([{"type": "table", "schema": []}, {"type": "view", "schema": []}, {"type": "schema"}]) + + +def test_list_tables_nonverbose_and_empty_result() -> None: + cursor = FakeCursor( + query_results={ + 'SHOW TABLES': { + 'description': [('Tables_in_test',)], + }, + 'SHOW FIELDS FROM missing_table': { + 'description': None, + }, + } + ) + + listed = list_tables(cursor) + assert listed[0].header == ['Tables_in_test'] + assert listed[0].rows is cursor + + described = list_tables(cursor, arg='missing_table') + assert described[0].header is None + assert described[0].rows is None + + +def test_list_databases_with_and_without_description() -> None: + cursor = FakeCursor( + query_results={ + 'SHOW DATABASES': { + 'description': [('Database',)], + }, + } + ) + + listed = list_databases(cursor) + assert listed[0].header == ['Database'] + assert listed[0].rows is cursor + + empty_cursor = FakeCursor(query_results={'SHOW DATABASES': {'description': None}}) + empty = list_databases(empty_cursor) + assert empty[0].header is None + assert empty[0].rows is None + + +def test_status_uses_global_queries_decodes_bytes_and_formats_stats(monkeypatch) -> None: + monkeypatch.setattr(dbcommands, '__version__', '9.9.9') + monkeypatch.setattr(dbcommands.platform, 'python_implementation', lambda: 'CPython') + monkeypatch.setattr(dbcommands.platform, 'python_version', lambda: '3.14.0') + monkeypatch.setattr(dbcommands.iocommands, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(dbcommands, 'get_ssl_cipher', lambda cur: 'TLS_AES_256_GCM_SHA384') + monkeypatch.setattr(dbcommands, 'get_ssl_version', lambda cur: 'TLSv1.3') + monkeypatch.setattr(dbcommands, 'format_uptime', lambda uptime: f'{uptime} seconds') + monkeypatch.setenv('PAGER', 'less -SR') + + cursor = FakeCursor( + connection=FakeConnection(host='tcp-host', port=3307, unix_socket=None), + query_results={ + 'SHOW GLOBAL STATUS;': { + 'rows': [ + (b'Uptime', b'10'), + (b'Threads_connected', b'5'), + (b'Queries', b'20'), + (b'Slow_queries', b'1'), + (b'Opened_tables', b'2'), + (b'Flush_commands', b'3'), + (b'Open_tables', b'4'), + ], + }, + 'SHOW GLOBAL VARIABLES;': { + 'rows': [ + (b'version', b'8.0.0'), + (b'version_comment', b'Community'), + (b'protocol_version', b'10'), + ], + }, + 'SELECT DATABASE(), USER();': { + 'rows': [('test_db', 'test_user')], + }, + 'SHOW SESSION VARIABLES;': { + 'rows': [ + (b'character_set_server', b'utf8mb4'), + (b'character_set_database', b'utf8mb4'), + (b'character_set_client', b'utf8mb4'), + (b'character_set_connection', b'utf8mb4'), + (b'character_set_results', b'utf8mb4'), + ], + }, + }, + ) + + result = status(cursor)[0] + + assert 'mycli 9.9.9 running on CPython 3.14.0' in result.preamble + assert ('Connection id:', 42) in result.rows + assert ('Current database:', 'test_db') in result.rows + assert ('Current user:', 'test_user') in result.rows + assert ('Current pager:', 'less -SR') in result.rows + assert ('Server version:', '8.0.0 Community') in result.rows + assert ('Protocol version:', '10') in result.rows + assert ('SSL:', 'Cipher in use is TLS_AES_256_GCM_SHA384') in result.rows + assert ('SSL/TLS version:', 'TLSv1.3') in result.rows + assert ('Connection:', 'tcp-host via TCP/IP') in result.rows + assert ('TCP port:', 3307) in result.rows + assert ('Uptime:', '10 seconds') in result.rows + assert 'Connections: 5' in result.postamble + assert 'Queries: 20' in result.postamble + assert 'Slow queries: 1' in result.postamble + assert 'Flush tables: 3' in result.postamble + assert 'Queries per second avg: 2.000' in result.postamble + assert '--------------' in result.postamble + + +def test_status_falls_back_to_show_status_and_handles_empty_selects(monkeypatch) -> None: + monkeypatch.setattr(dbcommands, '__version__', '1.0.0') + monkeypatch.setattr(dbcommands.platform, 'python_implementation', lambda: 'CPython') + monkeypatch.setattr(dbcommands.platform, 'python_version', lambda: '3.10.0') + monkeypatch.setattr(dbcommands.iocommands, 'is_pager_enabled', lambda: False) + monkeypatch.setattr(dbcommands, 'get_ssl_version', lambda cur: 'none') + monkeypatch.setattr(dbcommands, 'format_uptime', lambda uptime: f'{uptime} seconds') + + cursor = FakeCursor( + connection=FakeConnection(unix_socket='/tmp/mysql.sock'), + fail_on_queries={'SHOW GLOBAL STATUS;'}, + query_results={ + 'SHOW STATUS;': { + 'rows': [ + ('Slow_queries', '0'), + ('Opened_tables', '1'), + ('Open_tables', '2'), + ], + }, + 'SHOW GLOBAL VARIABLES;': { + 'rows': [ + ('version', '5.7.0'), + ('version_comment', 'Server'), + ('protocol_version', '10'), + ('socket', '/tmp/mysql.sock'), + ], + }, + 'SHOW SESSION VARIABLES;': { + 'rows': [], + }, + 'SELECT DATABASE(), USER();': { + 'rows': [], + }, + }, + ) + + result = status(cursor)[0] + + assert cursor.executed[0:2] == ['SHOW GLOBAL STATUS;', 'SHOW STATUS;'] + assert ('Current database:', '') in result.rows + assert ('Current user:', '') in result.rows + assert ('Current pager:', 'stdout') in result.rows + assert ('Connection:', 'Localhost via UNIX socket') in result.rows + assert ('UNIX socket:', '/tmp/mysql.sock') in result.rows + assert ('Server characterset:', '') in result.rows + assert ('Db characterset:', '') in result.rows + assert ('Client characterset:', '') in result.rows + assert ('Conn. characterset:', '') in result.rows + assert 'Connections:' not in result.postamble + assert '--------------' in result.postamble + + +def test_status_uses_system_default_pager_when_enabled_without_env(monkeypatch) -> None: + monkeypatch.setattr(dbcommands.iocommands, 'is_pager_enabled', lambda: True) + monkeypatch.setattr(dbcommands, 'get_ssl_version', lambda cur: 'TLS') + monkeypatch.setattr(dbcommands.platform, 'python_implementation', lambda: 'CPython') + monkeypatch.setattr(dbcommands.platform, 'python_version', lambda: '3.14.0') + monkeypatch.delenv('PAGER', raising=False) + + cursor = FakeCursor( + query_results={ + 'SHOW GLOBAL STATUS;': { + 'rows': [('Slow_queries', '0'), ('Opened_tables', '1'), ('Open_tables', '2')], + }, + 'SHOW GLOBAL VARIABLES;': { + 'rows': [('version', '8.0'), ('version_comment', 'Server'), ('protocol_version', '10')], + }, + 'SELECT DATABASE(), USER();': { + 'rows': [('db', 'user')], + }, + 'SHOW SESSION VARIABLES;': { + 'rows': [ + ('character_set_server', 'utf8'), + ('character_set_database', 'utf8'), + ('character_set_client', 'utf8'), + ('character_set_connection', 'utf8'), + ('character_set_results', 'utf8'), + ], + }, + }, + ) + + result = status(cursor)[0] + + assert ('Current pager:', 'System default') in result.rows diff --git a/test/pytests/test_special_iocommands.py b/test/pytests/test_special_iocommands.py new file mode 100644 index 00000000..c4f6a53e --- /dev/null +++ b/test/pytests/test_special_iocommands.py @@ -0,0 +1,931 @@ +# type: ignore + +import builtins +import os +from pathlib import Path +import stat +import subprocess +import tempfile +from time import time +from types import SimpleNamespace +from typing import Any, Generator +from unittest.mock import patch + +from pymysql import ProgrammingError +import pytest + +import mycli.packages.special +from mycli.packages.special import iocommands +from mycli.packages.sqlresult import SQLResult +from test.utils import TEMPFILE_PREFIX, db_connection, dbtest, send_ctrl_c + + +class FakeFavoriteQueries: + usage = '\nFAKE FAVORITES' + + def __init__(self, queries: dict[str, str] | None = None) -> None: + self.queries = {} if queries is None else dict(queries) + self.saved: list[tuple[str, str]] = [] + self.deleted: list[str] = [] + + def list(self) -> list[str]: + return list(self.queries) + + def get(self, name: str) -> str | None: + return self.queries.get(name) + + def save(self, name: str, query: str) -> None: + self.saved.append((name, query)) + self.queries[name] = query + + def delete(self, name: str) -> str: + self.deleted.append(name) + return f'{name}: Deleted.' + + +class FakeCursor: + def __init__(self, descriptions: dict[str, list[tuple[str]] | None] | None = None) -> None: + self.descriptions = {} if descriptions is None else descriptions + self.description: list[tuple[str]] | None = None + self.executed: list[str] = [] + + def execute(self, sql: str) -> None: + self.executed.append(sql) + self.description = self.descriptions.get(sql) + + +class SequenceCursor: + def __init__(self, descriptions: list[list[tuple[str]] | None]) -> None: + self.descriptions = descriptions + self.description: list[tuple[str]] | None = None + self.executed: list[str] = [] + + def execute(self, sql: str) -> None: + self.executed.append(sql) + self.description = self.descriptions.pop(0) + + +class FakeProcess: + def __init__( + self, + *, + stdout: bytes | str = b'', + stderr: bytes | str = b'', + returncode: int = 0, + raise_timeout: bool = False, + ) -> None: + self.stdout = stdout + self.stderr = stderr + self.returncode = returncode + self.raise_timeout = raise_timeout + self.communicate_calls = 0 + self.communicate_timeouts: list[int | None] = [] + self.killed = False + + def communicate(self, input: str | None = None, timeout: int | None = None) -> tuple[bytes | str, bytes | str]: # noqa: A002 + self.communicate_calls += 1 + self.communicate_timeouts.append(timeout) + if self.raise_timeout and self.communicate_calls == 1: + raise subprocess.TimeoutExpired(cmd='fake', timeout=timeout or 0) + return (self.stdout, self.stderr) + + def kill(self) -> None: + self.killed = True + + +@pytest.fixture(autouse=True) +def reset_iocommands_state(monkeypatch) -> Generator[None, None, None]: + original_timing = iocommands.TIMING_ENABLED + original_pager = iocommands.PAGER_ENABLED + original_show_favorite = iocommands.SHOW_FAVORITE_QUERY + original_force_horizontal = iocommands.force_horizontal_output + original_destructive_keywords = list(iocommands.DESTRUCTIVE_KEYWORDS) + original_once_file = iocommands.once_file + original_tee_file = iocommands.tee_file + original_written = iocommands.written_to_once_file + original_pipe_once = dict(iocommands.PIPE_ONCE) + original_favoritequeries = iocommands.favoritequeries + had_instance = hasattr(iocommands.FavoriteQueries, 'instance') + original_instance = getattr(iocommands.FavoriteQueries, 'instance', None) + + yield + + if iocommands.once_file and iocommands.once_file is not original_once_file: + iocommands.once_file.close() + if iocommands.tee_file and iocommands.tee_file is not original_tee_file: + iocommands.tee_file.close() + + iocommands.TIMING_ENABLED = original_timing + iocommands.PAGER_ENABLED = original_pager + iocommands.SHOW_FAVORITE_QUERY = original_show_favorite + iocommands.force_horizontal_output = original_force_horizontal + iocommands.DESTRUCTIVE_KEYWORDS = original_destructive_keywords + iocommands.once_file = original_once_file + iocommands.tee_file = original_tee_file + iocommands.written_to_once_file = original_written + iocommands.PIPE_ONCE.clear() + iocommands.PIPE_ONCE.update(original_pipe_once) + iocommands.favoritequeries = original_favoritequeries + if had_instance: + iocommands.FavoriteQueries.instance = original_instance + + +@pytest.fixture +def favorite_queries_instance(monkeypatch) -> None: + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', iocommands.favoritequeries, raising=False) + + +def test_set_get_pager(monkeypatch): + monkeypatch.setenv('PAGER', '') + mycli.packages.special.set_pager_enabled(True) + assert mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager_enabled(False) + assert not mycli.packages.special.is_pager_enabled() + mycli.packages.special.set_pager("less") + assert os.environ["PAGER"] == "less" + mycli.packages.special.set_pager(False) + assert os.environ["PAGER"] == "less" + del os.environ["PAGER"] + mycli.packages.special.set_pager(False) + mycli.packages.special.disable_pager() + assert not mycli.packages.special.is_pager_enabled() + + +def test_set_get_timing(): + mycli.packages.special.set_timing_enabled(True) + assert mycli.packages.special.is_timing_enabled() + mycli.packages.special.set_timing_enabled(False) + assert not mycli.packages.special.is_timing_enabled() + + +def test_set_get_expanded_output(): + mycli.packages.special.set_expanded_output(True) + assert mycli.packages.special.is_expanded_output() + mycli.packages.special.set_expanded_output(False) + assert not mycli.packages.special.is_expanded_output() + + +def test_editor_command(monkeypatch): + monkeypatch.setenv('EDITOR', 'true') + monkeypatch.setenv('VISUAL', 'true') + + assert mycli.packages.special.editor_command(r"hello\e") + assert mycli.packages.special.editor_command(r"hello\edit") + assert mycli.packages.special.editor_command(r"\e hello") + assert mycli.packages.special.editor_command(r"\edit hello") + + assert not mycli.packages.special.editor_command(r"HELP \e") + assert not mycli.packages.special.editor_command(r"help \edit\g") + assert not mycli.packages.special.editor_command(r"hello") + assert not mycli.packages.special.editor_command(r"\ehello") + assert not mycli.packages.special.editor_command(r"\edithello") + + assert mycli.packages.special.get_filename(r"\e filename") == "filename" + + if os.name != "nt": + assert mycli.packages.special.open_external_editor(sql=r"select 1") == ('select 1', None) + else: + pytest.skip("Skipping on Windows platform.") + + +def test_tee_command(): + mycli.packages.special.write_tee("hello world") # write without file set + # keep Windows from locking the file with delete=False + with tempfile.NamedTemporaryFile(prefix=TEMPFILE_PREFIX, delete=False) as f: + mycli.packages.special.execute(None, "tee " + f.name) + mycli.packages.special.write_tee("hello world") + if os.name == "nt": + assert f.read() == b"hello world\r\n" + else: + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, "tee -o " + f.name) + mycli.packages.special.write_tee("hello world") + f.seek(0) + if os.name == "nt": + assert f.read() == b"hello world\r\n" + else: + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, "notee") + mycli.packages.special.write_tee("hello world") + f.seek(0) + if os.name == "nt": + assert f.read() == b"hello world\r\n" + else: + assert f.read() == b"hello world\n" + + # remove temp file + # delete=False means we should try to clean up + try: + if os.path.exists(f.name): + os.remove(f.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +def test_tee_command_error(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, "tee") + + with pytest.raises(OSError): + with tempfile.NamedTemporaryFile(prefix=TEMPFILE_PREFIX) as f: + os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) + mycli.packages.special.execute(None, f"tee {f.name}") + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") +def test_favorite_query(favorite_queries_instance) -> None: + with db_connection().cursor() as cur: + query = 'select "✔"' + mycli.packages.special.execute(cur, f"\\fs check {query}") + assert next(mycli.packages.special.execute(cur, "\\f check")).preamble == "> " + query + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") +def test_special_favorite_query(favorite_queries_instance) -> None: + with db_connection().cursor() as cur: + query = r'\?' + mycli.packages.special.execute(cur, rf"\fs special {query}") + assert (r'\G', None, r'\G', 'Display query results vertically.') in next( + mycli.packages.special.execute(cur, r'\f special') + ).rows + + +def test_once_command(): + with pytest.raises(TypeError): + mycli.packages.special.execute(None, "\\once") + + with pytest.raises(OSError): + mycli.packages.special.execute(None, "\\once /proc/access-denied") + + mycli.packages.special.write_once("hello world") # write without file set + # keep Windows from locking the file with delete=False + with tempfile.NamedTemporaryFile(prefix=TEMPFILE_PREFIX, delete=False) as f: + mycli.packages.special.execute(None, "\\once " + f.name) + mycli.packages.special.write_once("hello world") + if os.name == "nt": + assert f.read() == b"hello world\r\n" + else: + assert f.read() == b"hello world\n" + + mycli.packages.special.execute(None, "\\once -o " + f.name) + mycli.packages.special.write_once("hello world line 1") + mycli.packages.special.write_once("hello world line 2") + f.seek(0) + if os.name == "nt": + assert f.read() == b"hello world line 1\r\nhello world line 2\r\n" + else: + assert f.read() == b"hello world line 1\nhello world line 2\n" + # delete=False means we should try to clean up + try: + if os.path.exists(f.name): + os.remove(f.name) + except Exception as e: + print(f"An error occurred while attempting to delete the file: {e}") + + +def test_pipe_once_command(): + with pytest.raises(IOError): + mycli.packages.special.execute(None, "\\pipe_once") + + with pytest.raises(OSError): + mycli.packages.special.execute(None, "\\pipe_once /proc/access-denied") + mycli.packages.special.write_pipe_once("select 1") + mycli.packages.special.flush_pipe_once_if_written(None) + + if os.name == "nt": + mycli.packages.special.execute(None, '\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') + mycli.packages.special.write_once("hello world") + mycli.packages.special.flush_pipe_once_if_written(None) + else: + with tempfile.NamedTemporaryFile(prefix=TEMPFILE_PREFIX) as f: + mycli.packages.special.execute(None, "\\pipe_once tee " + f.name) + mycli.packages.special.write_pipe_once("hello world") + mycli.packages.special.flush_pipe_once_if_written(None) + f.seek(0) + assert f.read() == b"hello world\n" + + +def test_parseargfile(): + """Test that parseargfile expands the user directory.""" + expected = (os.path.join(os.path.expanduser("~"), "filename"), "a") + + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("~\\filename") + else: + assert expected == mycli.packages.special.iocommands.parseargfile("~/filename") + + expected = (os.path.join(os.path.expanduser("~"), "filename"), "w") + if os.name == "nt": + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~\\filename") + else: + assert expected == mycli.packages.special.iocommands.parseargfile("-o ~/filename") + + +def test_parseargfile_no_file(): + """Test that parseargfile raises a TypeError if there is no filename.""" + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile("") + + with pytest.raises(TypeError): + mycli.packages.special.iocommands.parseargfile("-o ") + + +@dbtest +def test_watch_query_iteration(): + """Test that a single iteration of the result of `watch_query` executes + the desired query and returns the given results.""" + expected_value = "1" + query = f"SELECT {expected_value}" + expected_preamble = f"> {query}" + with db_connection().cursor() as cur: + result = next(mycli.packages.special.iocommands.watch_query(arg=query, cur=cur)) + assert result.preamble == expected_preamble + assert result.header[0] == expected_value + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: Win handles this differently. May need to refactor watch_query to work for Win") +def test_watch_query_full(): + """Test that `watch_query`: + + * Returns the expected results. + * Executes the defined times inside the given interval, in this case with + a 0.3 seconds wait, it should execute 4 times inside a 1 seconds + interval. + * Stops at Ctrl-C + + """ + watch_seconds = 0.3 + wait_interval = 1 + expected_value = "1" + query = f"SELECT {expected_value}" + expected_preamble = f"> {query}" + # Python 3.14 is skipping ahead to 6 or 7 + # Python 3.11 is as slow as 3 + expected_results = [3, 4, 5, 6, 7] + ctrl_c_process = send_ctrl_c(wait_interval) + with db_connection().cursor() as cur: + results = list(mycli.packages.special.iocommands.watch_query(arg=f"{watch_seconds} {query}", cur=cur)) + ctrl_c_process.join(1) + assert len(results) in expected_results + for result in results: + assert result.preamble == expected_preamble + assert result.header[0] == expected_value + + +@dbtest +@patch("click.clear") +def test_watch_query_clear(clear_mock): + """Test that the screen is cleared with the -c flag of `watch` command + before execute the query.""" + with db_connection().cursor() as cur: + watch_gen = mycli.packages.special.iocommands.watch_query(arg="0.1 -c select 1;", cur=cur) + assert not clear_mock.called + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + next(watch_gen) + assert clear_mock.called + clear_mock.reset_mock() + + +@dbtest +def test_watch_query_bad_arguments(): + """Test different incorrect combinations of arguments for `watch` + command.""" + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + with pytest.raises(ProgrammingError): + next(watch_query("a select 1;", cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query("-a select 1;", cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query("1 -a select 1;", cur=cur)) + with pytest.raises(ProgrammingError): + next(watch_query("-c -a select 1;", cur=cur)) + + +@dbtest +@patch("click.clear") +def test_watch_query_interval_clear(clear_mock): + """Test `watch` command with interval and clear flag.""" + + def test_asserts(gen): + clear_mock.reset_mock() + start = time() + next(gen) + assert clear_mock.called + next(gen) + exec_time = time() - start + assert exec_time > seconds and exec_time < (seconds + seconds) + + seconds = 1.0 + watch_query = mycli.packages.special.iocommands.watch_query + with db_connection().cursor() as cur: + test_asserts(watch_query(f"{seconds} -c select 1;", cur=cur)) + test_asserts(watch_query(f"-c {seconds} select 1;", cur=cur)) + + +def test_split_sql_by_delimiter(): + for delimiter_str in (";", "$", "😀"): + mycli.packages.special.set_delimiter(delimiter_str) + sql_input = f"select 1{delimiter_str} select \ufffc2" + queries = ("select 1", "select \ufffc2") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input), strict=True): + assert query == parsed_query + + +def test_switch_delimiter_within_query(): + mycli.packages.special.set_delimiter(";") + sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" + queries = ("select 1", "delimiter $$ select 2 $$ select 3 $$") + for query, parsed_query in zip(queries, mycli.packages.special.split_queries(sql_input), strict=True): + assert query == parsed_query + + +def test_set_delimiter(): + for delim in ("foo", "bar"): + mycli.packages.special.set_delimiter(delim) + assert mycli.packages.special.get_current_delimiter() == delim + + +def teardown_function(): + mycli.packages.special.set_delimiter(";") + + +def test_simple_setters_and_toggle_timing() -> None: + config = {'favorite_queries': {'demo': 'select 1'}} + + iocommands.set_favorite_queries(config) + assert iocommands.favoritequeries.config is config + + iocommands.set_show_favorite_query(False) + assert iocommands.is_show_favorite_query() is False + + iocommands.set_show_warnings_enabled(True) + assert iocommands.is_show_warnings_enabled() is True + iocommands.set_show_warnings_enabled(False) + assert iocommands.is_show_warnings_enabled() is False + + iocommands.set_destructive_keywords(['drop']) + assert iocommands.DESTRUCTIVE_KEYWORDS == ['drop'] + + iocommands.set_forced_horizontal_output(True) + assert iocommands.forced_horizontal() is True + + iocommands.set_timing_enabled(False) + assert iocommands.toggle_timing()[0].status == 'Timing is on.' + assert iocommands.toggle_timing()[0].status == 'Timing is off.' + + +def test_enable_show_warnings_updates_special_state() -> None: + result = next(iocommands.enable_show_warnings()) + + assert result.status == 'Show warnings enabled.' + assert iocommands.is_show_warnings_enabled() is True + + +def test_disable_show_warnings_updates_special_state() -> None: + result = next(iocommands.disable_show_warnings()) + + assert result.status == 'Show warnings disabled.' + assert iocommands.is_show_warnings_enabled() is False + + +def test_editor_helpers_strip_commands() -> None: + assert iocommands.get_filename(r'\edit ') is None + assert iocommands.get_filename('select 1') is None + assert iocommands.get_editor_query(r' select * from style\edit\e ') == 'select * from style' + + +def test_open_external_editor_filename_paths(monkeypatch, tmp_path: Path) -> None: + filename = tmp_path / 'query.sql' + filename.write_text('select 1\n', encoding='utf-8') + edit_calls: list[str] = [] + + monkeypatch.setattr(iocommands.click, 'edit', lambda filename: edit_calls.append(filename)) + query, message = iocommands.open_external_editor(filename=f'{filename} ignored', sql='unused') + + assert query == 'select 1' + assert message is None + assert edit_calls == [str(filename)] + + def raise_ioerror(*_args, **_kwargs): + raise IOError('boom') + + monkeypatch.setattr(iocommands.click, 'edit', lambda filename: None) + monkeypatch.setattr(builtins, 'open', raise_ioerror) + + query, message = iocommands.open_external_editor(filename=str(filename)) + + assert query == '' + assert message == f'Error reading file: {filename}' + + +def test_open_external_editor_without_filename(monkeypatch) -> None: + calls: list[tuple[str, str]] = [] + marker = '# Type your query above this line.\n' + + def fake_edit(text: str, extension: str) -> str: + calls.append((text, extension)) + return f'select 1\n\n{marker}ignored' + + monkeypatch.setattr(iocommands.click, 'edit', fake_edit) + query, message = iocommands.open_external_editor(sql='select 1') + + assert query == 'select 1' + assert message is None + assert calls == [(f'select 1\n\n{marker}', '.sql')] + + monkeypatch.setattr(iocommands.click, 'edit', lambda text, extension: None) + query, message = iocommands.open_external_editor(sql='select fallback') + + assert query == 'select fallback' + assert message is None + + +def test_clip_helpers_and_clipboard(monkeypatch) -> None: + assert iocommands.clip_command(r'\clip select 1') + assert iocommands.clip_command(r'select 1 \clip') + assert not iocommands.clip_command(r'select 1') + assert iocommands.get_clip_query(r'\clip select 1\clip') == ' select 1' + + copied: list[str] = [] + monkeypatch.setattr(iocommands.pyperclip, 'copy', lambda text: copied.append(text)) + assert iocommands.copy_query_to_clipboard('select 1') is None + assert copied == ['select 1'] + + def raise_runtime_error(_text: str) -> None: + raise RuntimeError('no clipboard') + + monkeypatch.setattr(iocommands.pyperclip, 'copy', raise_runtime_error) + assert iocommands.copy_query_to_clipboard() == 'Error clipping query: no clipboard.' + + +def test_set_redirect_routes_to_pipe_once_and_once(monkeypatch) -> None: + pipe_calls: list[str] = [] + once_calls: list[str] = [] + + def fake_set_pipe_once(arg: str) -> list[tuple[str]]: + pipe_calls.append(arg) + return [('pipe',)] + + def fake_set_once(arg: str) -> list[tuple[str]]: + once_calls.append(arg) + return [('once',)] + + monkeypatch.setattr(iocommands, 'set_pipe_once', fake_set_pipe_once) + monkeypatch.setattr(iocommands, 'set_once', fake_set_once) + + iocommands.PIPE_ONCE['stdout_file'] = None + iocommands.PIPE_ONCE['stdout_mode'] = None + result = iocommands.set_redirect('cat', '>', 'out.txt') + assert result == [('pipe',)] + assert pipe_calls == ['cat'] + assert iocommands.PIPE_ONCE['stdout_file'] == 'out.txt' + assert iocommands.PIPE_ONCE['stdout_mode'] == 'w' + + assert iocommands.set_redirect(None, '>', 'other.txt') == [('once',)] + assert iocommands.set_redirect(None, None, 'append.txt') == [('once',)] + assert once_calls == ['-o other.txt', 'append.txt'] + + +def test_execute_favorite_query_list_missing_and_bad_args(monkeypatch) -> None: + favorite_queries = FakeFavoriteQueries({'demo': 'select $1'}) + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', favorite_queries, raising=False) + + listed = SQLResult(status='listed') + monkeypatch.setattr(iocommands, 'list_favorite_queries', lambda: [listed]) + assert list(iocommands.execute_favorite_query(FakeCursor(), '')) == [listed] + + missing = list(iocommands.execute_favorite_query(FakeCursor(), 'unknown')) + assert missing[0].status == 'No favorite query: unknown' + + bad_args = list(iocommands.execute_favorite_query(FakeCursor(), 'demo')) + assert bad_args[0].status == 'missing substitution for $1 in query:\n select $1' + + +def test_execute_favorite_query_special_and_plain_sql(monkeypatch) -> None: + favorite_queries = FakeFavoriteQueries({'combo': 'help demo; select 1'}) + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', favorite_queries, raising=False) + monkeypatch.setattr(iocommands, 'SPECIAL_COMMANDS', {'help': object()}) + monkeypatch.setattr(iocommands, 'special_execute', lambda cur, sql: [SQLResult(status=f'ran {sql}')]) + + cursor = FakeCursor({'select 1': None}) + results = list(iocommands.execute_favorite_query(cursor, 'combo')) + + assert results[0].status == 'ran help demo' + assert results[0].preamble == '> help demo' + assert results[1].preamble == '> select 1' + assert results[1].header is None + assert cursor.executed == ['select 1'] + + +def test_execute_favorite_query_returns_header_for_result_sets(monkeypatch) -> None: + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', FakeFavoriteQueries({'rows': 'select 2'}), raising=False) + + cursor = FakeCursor({'select 2': [('col',)]}) + results = list(iocommands.execute_favorite_query(cursor, 'rows')) + + assert results[0].preamble == '> select 2' + assert results[0].header == ['col'] + assert results[0].rows is cursor + + +def test_list_substitute_save_delete_and_redirect_state(tmp_path: Path, monkeypatch) -> None: + empty_favorites = FakeFavoriteQueries() + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', empty_favorites, raising=False) + empty_result = iocommands.list_favorite_queries()[0] + assert empty_result.header == ['Name', 'Query'] + assert empty_result.rows == [] + assert empty_result.status == '\nNo favorite queries found.' + empty_favorites.usage + + populated_favorites = FakeFavoriteQueries({'demo': 'select 1'}) + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', populated_favorites, raising=False) + rows_result = iocommands.list_favorite_queries()[0] + assert rows_result.rows == [('demo', 'select 1')] + assert rows_result.status == '' + + assert iocommands.subst_favorite_query_args('select $1', ['x']) == ['select x', None] + assert iocommands.subst_favorite_query_args('select 1', ['x']) == [None, 'query does not have substitution parameter $1:\n select 1'] + assert iocommands.subst_favorite_query_args('select $1, $2', ['x']) == [None, 'missing substitution for $2 in query:\n select x, $2'] + + assert iocommands.save_favorite_query('', cur=None)[0].status == 'Syntax: \\fs name query.\n\n' + populated_favorites.usage + assert iocommands.save_favorite_query('onlyname', cur=None)[0].status == ( + 'Syntax: \\fs name query.\n\n' + populated_favorites.usage + ' Err: Both name and query are required.' + ) + assert iocommands.save_favorite_query('saved select 2', cur=None)[0].status == 'Saved.' + assert populated_favorites.saved == [('saved', 'select 2')] + + assert iocommands.delete_favorite_query('', cur=None)[0].status == 'Syntax: \\fd name.\n\n' + populated_favorites.usage + assert iocommands.delete_favorite_query('saved', cur=None)[0].status == 'saved: Deleted.' + assert populated_favorites.deleted == ['saved'] + + iocommands.once_file = None + iocommands.PIPE_ONCE['process'] = None + assert iocommands.is_redirected() is False + redirect_file = (tmp_path / 'redirect.txt').open('w', encoding='utf-8') + iocommands.once_file = redirect_file + assert iocommands.is_redirected() is True + redirect_file.close() + iocommands.once_file = None + iocommands.PIPE_ONCE['process'] = SimpleNamespace() + assert iocommands.is_redirected() is True + + +def test_execute_system_command_usage_parse_and_cd(monkeypatch) -> None: + usage = 'Syntax: system [-r] [command].\n-r denotes "raw" mode, in which output is passed through without formatting.' + assert iocommands.execute_system_command('')[0].status == usage + assert iocommands.execute_system_command('-r')[0].status == usage + + def raise_value_error(*_args, **_kwargs): + raise ValueError('bad quoting') + + monkeypatch.setattr(iocommands.shlex, 'split', raise_value_error) + assert iocommands.execute_system_command('broken')[0].status == 'Cannot parse system command: bad quoting' + + monkeypatch.setattr(iocommands.shlex, 'split', lambda arg, posix: ['cd', '/tmp']) + monkeypatch.setattr(iocommands, 'handle_cd_command', lambda command: (False, 'cd failed')) + assert iocommands.execute_system_command('cd /tmp')[0].status == 'cd failed' + + monkeypatch.setattr(iocommands, 'handle_cd_command', lambda command: (True, None)) + success_result = iocommands.execute_system_command('cd /tmp')[0] + assert success_result.status is None + assert success_result.preamble is None + + +@pytest.mark.parametrize( + ('command', 'returncode', 'expected_status'), + [ + ('-r echo ok', 0, None), + ('vim file.sql', 1, 'Command exited with return code 1'), + ], +) +def test_execute_system_command_raw_modes( + monkeypatch, + command: str, + returncode: int, + expected_status: str | None, +) -> None: + calls: list[list[str]] = [] + + def fake_run(cmd: list[str], check: bool = False) -> SimpleNamespace: + calls.append(cmd) + return SimpleNamespace(returncode=returncode) + + monkeypatch.setattr(iocommands.subprocess, 'run', fake_run) + result = iocommands.execute_system_command(command)[0] + + assert calls + assert result.status == expected_status + + +def test_execute_system_command_nonraw_paths(monkeypatch) -> None: + monkeypatch.setattr(iocommands.locale, 'getpreferredencoding', lambda do_setlocale: 'utf-8') + + timeout_process = FakeProcess(stdout=b'timed out output', stderr=b'', returncode=0, raise_timeout=True) + timeout_popen_calls: list[tuple[list[str], int, int]] = [] + + def fake_timeout_popen(command: list[str], stdout: int, stderr: int) -> FakeProcess: + timeout_popen_calls.append((command, stdout, stderr)) + return timeout_process + + monkeypatch.setattr( + iocommands.subprocess, + 'Popen', + fake_timeout_popen, + ) + result = iocommands.execute_system_command('echo slow')[0] + assert result.preamble == 'timed out output' + assert result.status is None + assert timeout_popen_calls == [ + ( + ['echo', 'slow'], + iocommands.subprocess.PIPE, + iocommands.subprocess.PIPE, + ) + ] + assert timeout_process.communicate_timeouts == [60, None] + assert timeout_process.killed is True + + error_process = FakeProcess(stdout=b'ignored', stderr=b'boom', returncode=7) + error_popen_calls: list[tuple[list[str], int, int]] = [] + + def fake_error_popen(command: list[str], stdout: int, stderr: int) -> FakeProcess: + error_popen_calls.append((command, stdout, stderr)) + return error_process + + monkeypatch.setattr( + iocommands.subprocess, + 'Popen', + fake_error_popen, + ) + error_result = iocommands.execute_system_command('echo fail')[0] + assert error_result.preamble == 'boom' + assert error_result.status == 'Command exited with return code 7' + assert error_popen_calls == [ + ( + ['echo', 'fail'], + iocommands.subprocess.PIPE, + iocommands.subprocess.PIPE, + ) + ] + assert error_process.communicate_timeouts == [60] + + def raise_oserror(command, stdout, stderr): + raise OSError(0, 'bad command') + + monkeypatch.setattr(iocommands.subprocess, 'Popen', raise_oserror) + assert iocommands.execute_system_command('echo nope')[0].status == 'OSError: bad command' + + +def test_unset_once_and_post_redirect_hook(monkeypatch, tmp_path: Path) -> None: + target = tmp_path / 'once.txt' + iocommands.once_file = target.open('w', encoding='utf-8') + iocommands.written_to_once_file = True + hook_calls: list[tuple[str, str]] = [] + original_run_post_redirect_hook = iocommands._run_post_redirect_hook + + def fake_run_post_redirect_hook(command: str, filename: str) -> None: + hook_calls.append((command, filename)) + + monkeypatch.setattr(iocommands, '_run_post_redirect_hook', fake_run_post_redirect_hook) + + iocommands.unset_once_if_written('post {}') + + assert iocommands.once_file is None + assert hook_calls == [('post {}', str(target))] # type: ignore[unreachable] + monkeypatch.setattr(iocommands, '_run_post_redirect_hook', original_run_post_redirect_hook) + + run_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def fake_run(*args, **kwargs) -> SimpleNamespace: + run_calls.append((args, kwargs)) + return SimpleNamespace(returncode=0) + + monkeypatch.setattr(iocommands.subprocess, 'run', fake_run) + iocommands._run_post_redirect_hook('', str(target)) + assert run_calls == [] + + iocommands._run_post_redirect_hook('cat {}', str(target)) + assert run_calls[0][0] == ('cat ' + iocommands.shlex.quote(str(target)),) + assert run_calls[0][1] == { + 'shell': True, + 'check': True, + 'stdin': iocommands.subprocess.DEVNULL, + 'stdout': iocommands.subprocess.DEVNULL, + 'stderr': iocommands.subprocess.DEVNULL, + } + + def raise_run(*_args, **_kwargs): + raise RuntimeError('hook failed') + + monkeypatch.setattr(iocommands.subprocess, 'run', raise_run) + with pytest.raises(OSError, match='Redirect post hook failed: hook failed'): + iocommands._run_post_redirect_hook('cat {}', str(target)) + + +def test_set_pipe_once_and_flush_short_circuits(monkeypatch) -> None: + popen_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + monkeypatch.setattr(iocommands, 'WIN', True) + monkeypatch.setattr(iocommands.shlex, 'split', lambda arg: ['cmd', '/c', arg]) + + def fake_popen(*args, **kwargs) -> SimpleNamespace: + popen_calls.append((args, kwargs)) + return SimpleNamespace() + + monkeypatch.setattr(iocommands.subprocess, 'Popen', fake_popen) + + assert iocommands.set_pipe_once('echo test')[0].status == '' + assert popen_calls == [ + ( + (['cmd', '/c', 'echo test'],), + { + 'stdin': iocommands.subprocess.PIPE, + 'stdout': iocommands.subprocess.PIPE, + 'stderr': iocommands.subprocess.PIPE, + 'encoding': 'UTF-8', + 'universal_newlines': True, + }, + ) + ] + + iocommands.PIPE_ONCE['process'] = None + iocommands.PIPE_ONCE['stdin'] = ['line'] + iocommands.flush_pipe_once_if_written('post {}') + + iocommands.PIPE_ONCE['process'] = SimpleNamespace() + iocommands.PIPE_ONCE['stdin'] = [] + iocommands.flush_pipe_once_if_written('post {}') + + +def test_flush_pipe_once_timeout_and_nonzero_exit(monkeypatch, tmp_path: Path) -> None: + output_file = tmp_path / 'pipe.txt' + process = FakeProcess(stdout='stdout data', stderr='stderr data', returncode=9, raise_timeout=True) + hook_calls: list[tuple[str, str]] = [] + secho_calls: list[tuple[str, dict[str, Any]]] = [] + + monkeypatch.setattr(iocommands, '_run_post_redirect_hook', lambda command, filename: hook_calls.append((command, filename))) + monkeypatch.setattr(iocommands.click, 'secho', lambda message, **kwargs: secho_calls.append((message, kwargs))) + + iocommands.PIPE_ONCE['process'] = process + iocommands.PIPE_ONCE['stdin'] = ['select 1'] + iocommands.PIPE_ONCE['stdout_file'] = str(output_file) + iocommands.PIPE_ONCE['stdout_mode'] = 'w' + + with pytest.raises(OSError, match='process exited with nonzero code 9'): + iocommands.flush_pipe_once_if_written('post {}') + + assert process.killed is True + assert output_file.read_text(encoding='utf-8') == 'stdout data\n' + assert hook_calls == [('post {}', str(output_file))] + assert secho_calls == [('stderr data', {'err': True, 'fg': 'red'})] + assert iocommands.PIPE_ONCE == { + 'process': None, + 'stdin': [], + 'stdout_file': None, + 'stdout_mode': None, + } + + +def test_watch_query_usage_and_destructive_cancel(monkeypatch) -> None: + usage_results = list(iocommands.watch_query('', cur=SequenceCursor([None]))) + assert usage_results[0].status and usage_results[0].status.startswith('Syntax: watch') + + usage_missing_statement = list(iocommands.watch_query('5 -c', cur=SequenceCursor([None]))) + assert usage_missing_statement[0].status and usage_missing_statement[0].status.startswith('Syntax: watch') + + secho_calls: list[str] = [] + monkeypatch.setattr(iocommands, 'confirm_destructive_query', lambda keywords, statement: False) + monkeypatch.setattr(iocommands.click, 'secho', lambda message, **kwargs: secho_calls.append(message)) + + assert list(iocommands.watch_query('drop table t', cur=SequenceCursor([None]))) == [] + assert secho_calls == ['Wise choice!'] + + +def test_watch_query_confirmed_without_description_and_keyboard_interrupt(monkeypatch) -> None: + cursor = SequenceCursor([None]) + secho_calls: list[str] = [] + + monkeypatch.setattr(iocommands, 'confirm_destructive_query', lambda keywords, statement: True) + monkeypatch.setattr(iocommands.click, 'secho', lambda message, **kwargs: secho_calls.append(message)) + monkeypatch.setattr(iocommands, 'sleep', lambda seconds: (_ for _ in ()).throw(KeyboardInterrupt())) + + iocommands.set_pager_enabled(True) + generator = iocommands.watch_query('0.1 select 1;', cur=cursor) + result = next(generator) + + assert result.preamble == '> select 1;' + assert result.header is None + assert result.command == {'name': 'watch', 'seconds': 0.1} + assert iocommands.is_pager_enabled() is False + + with pytest.raises(StopIteration): + next(generator) + + assert secho_calls == ['Your call!', ''] + assert iocommands.is_pager_enabled() is True diff --git a/test/pytests/test_special_llm.py b/test/pytests/test_special_llm.py new file mode 100644 index 00000000..9ca28150 --- /dev/null +++ b/test/pytests/test_special_llm.py @@ -0,0 +1,548 @@ +import builtins +import importlib +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import patch + +import click +import pytest + +from mycli.packages.special import llm as llm_module +from mycli.packages.special.llm import ( + NEED_DEPENDENCIES, + USAGE, + _build_command_tree, + build_command_tree, + ensure_mycli_template, + get_completions, + get_sample_data, + get_schema, + handle_llm, + is_llm_command, + run_external_cmd, + sql_using_llm, + truncate_list_elements, + truncate_table_lines, +) +from mycli.packages.special.main import COMMANDS +from mycli.packages.sqlresult import SQLResult + + +# Override executor fixture to avoid real DB connections during llm tests +@pytest.fixture +def executor(): + """Dummy executor fixture""" + return None + + +def test_reload_llm_module_handles_disabled_and_import_error_paths(monkeypatch) -> None: + with monkeypatch.context() as m: + m.setenv("MYCLI_LLM_OFF", "1") + importlib.reload(llm_module) + assert llm_module.LLM_IMPORTED is False + assert llm_module.LLM_CLI_IMPORTED is False + + importlib.reload(llm_module) + + original_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): # noqa: A002 + if name == "llm" or name == "llm.cli": + raise ImportError("no llm") + return original_import(name, globals, locals, fromlist, level) + + with monkeypatch.context() as m: + m.delenv("MYCLI_LLM_OFF", raising=False) + m.setattr(builtins, "__import__", fake_import) + importlib.reload(llm_module) + assert llm_module.LLM_IMPORTED is False + assert llm_module.LLM_CLI_IMPORTED is False + + importlib.reload(llm_module) + + +def test_reload_llm_module_handles_cli_import_error(monkeypatch) -> None: + original_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): # noqa: A002 + if name == "llm.cli": + raise ImportError("no llm cli") + return original_import(name, globals, locals, fromlist, level) + + with monkeypatch.context() as m: + m.delenv("MYCLI_LLM_OFF", raising=False) + m.setattr(builtins, "__import__", fake_import) + importlib.reload(llm_module) + assert llm_module.LLM_IMPORTED is True + assert llm_module.LLM_CLI_IMPORTED is False + + importlib.reload(llm_module) + + +def test_build_command_tree_handles_groups_models_and_leaf(monkeypatch) -> None: + monkeypatch.setattr( + llm_module, + "llm", + SimpleNamespace(get_models=lambda: [SimpleNamespace(model_id="gpt-4o"), SimpleNamespace(model_id="llama3")]), + raising=False, + ) + + models_group = click.Group("models") + models_group.add_command(click.Command("default")) + root = click.Group("root") + root.add_command(click.Command("prompt")) + root.add_command(models_group) + + assert _build_command_tree(root) == { + "prompt": None, + "models": {"default": {"gpt-4o": None, "llama3": None}}, + } + assert build_command_tree(click.Command("leaf")) == {} + + +def test_get_completions_walks_tree_and_skips_flags() -> None: + tree = { + "models": {"default": {"gpt-4o": None}}, + "prompt": None, + } + + assert get_completions([], tree) == ["models", "prompt"] + assert get_completions(["models"], tree) == ["default"] + assert get_completions(["models", "--help"], tree) == ["default"] + assert get_completions(["models", "default"], tree) == ["gpt-4o"] + assert get_completions(["missing"], tree) == [] + assert get_completions(["prompt"], tree) == [] + + +def test_cli_commands_is_cached(monkeypatch) -> None: + llm_module.cli_commands.cache_clear() + monkeypatch.setattr(llm_module, "cli", SimpleNamespace(commands={"models": object(), "prompt": object()})) + + assert llm_module.cli_commands() == ["models", "prompt"] + + monkeypatch.setattr(llm_module, "cli", SimpleNamespace(commands={"install": object()})) + assert llm_module.cli_commands() == ["models", "prompt"] + llm_module.cli_commands.cache_clear() + + +def test_run_external_cmd_capture_output_and_restore_argv(monkeypatch, capsys) -> None: + original_argv = list(llm_module.sys.argv) + + def fake_run_module(cmd: str, run_name: str) -> None: + assert cmd == "llm" + assert run_name == "__main__" + print("stdout text") + llm_module.sys.stderr.write("stderr text") + + monkeypatch.setattr(llm_module, "run_module", fake_run_module) + + code, output = run_external_cmd("llm", "models", capture_output=True) + + assert code == 0 + assert "stdout text" in output + assert "stderr text" in output + assert llm_module.sys.argv == original_argv + captured = capsys.readouterr() + assert captured.out == "" + assert captured.err == "" + + +def test_run_external_cmd_nonzero_exit_raises_with_output(monkeypatch) -> None: + def fake_run_module(cmd: str, run_name: str) -> None: + print("failed output") + raise SystemExit(2) + + monkeypatch.setattr(llm_module, "run_module", fake_run_module) + + with pytest.raises(RuntimeError, match="failed output"): + run_external_cmd("llm", capture_output=True) + + +def test_run_external_cmd_nonzero_exit_raises_without_output(monkeypatch) -> None: + monkeypatch.setattr(llm_module, "run_module", lambda cmd, run_name: (_ for _ in ()).throw(SystemExit(3))) + + with pytest.raises(RuntimeError, match=r"Command llm failed with exit code 3\."): + run_external_cmd("llm") + + +def test_run_external_cmd_exception_paths_and_restart(monkeypatch) -> None: + monkeypatch.setattr(llm_module, "run_module", lambda cmd, run_name: (_ for _ in ()).throw(ValueError("boom"))) + + with pytest.raises(RuntimeError, match=r"Command llm failed: boom"): + run_external_cmd("llm") + + def fake_run_module_capture(cmd: str, run_name: str) -> None: + print("capture boom") + raise ValueError("boom") + + monkeypatch.setattr(llm_module, "run_module", fake_run_module_capture) + with pytest.raises(RuntimeError, match="capture boom"): + run_external_cmd("llm", capture_output=True) + + execv_calls: list[tuple[str, list[str]]] = [] + monkeypatch.setattr(llm_module, "run_module", lambda cmd, run_name: (_ for _ in ()).throw(SystemExit(0))) + monkeypatch.setattr(llm_module.os, "execv", lambda exe, args: execv_calls.append((exe, args))) + + code, output = run_external_cmd("llm", "install", restart_cli=True) + + assert code == 0 + assert output == "" + assert execv_calls == [(llm_module.sys.executable, [llm_module.sys.executable] + llm_module.sys.argv)] + + +def test_ensure_mycli_template_returns_early_or_replaces(monkeypatch) -> None: + calls: list[tuple] = [] + + def fake_run_external_cmd(*args, **kwargs): + calls.append((args, kwargs)) + return (0, "") + + monkeypatch.setattr(llm_module, "run_external_cmd", fake_run_external_cmd) + ensure_mycli_template() + + assert calls == [ + (("llm", "templates", "show", llm_module.LLM_TEMPLATE_NAME), {"capture_output": True, "raise_exception": False}), + ] + + calls.clear() + + def fake_run_external_cmd_missing(*args, **kwargs): + calls.append((args, kwargs)) + return (1, "") if len(calls) == 1 else (0, "") + + monkeypatch.setattr(llm_module, "run_external_cmd", fake_run_external_cmd_missing) + ensure_mycli_template() + + assert calls == [ + (("llm", "templates", "show", llm_module.LLM_TEMPLATE_NAME), {"capture_output": True, "raise_exception": False}), + (("llm", llm_module.PROMPT, "--save", llm_module.LLM_TEMPLATE_NAME), {}), + ] + + calls.clear() + monkeypatch.setattr(llm_module, "run_external_cmd", fake_run_external_cmd) + ensure_mycli_template(replace=True) + + assert calls == [ + (("llm", llm_module.PROMPT, "--save", llm_module.LLM_TEMPLATE_NAME), {}), + ] + + +@patch("mycli.packages.special.llm.llm") +def test_llm_command_without_args(mock_llm, executor): + r""" + Invoking \llm without any arguments should print the usage and raise FinishIteration. + """ + assert mock_llm is not None + test_text = r"\llm" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + # Should return usage message when no args provided + assert exc_info.value.results == [SQLResult(preamble=USAGE)] + + +@patch("mycli.packages.special.llm.llm") +def test_llm_command_with_help_subcommand(mock_llm, executor): + r""" + Invoking \llm with "help" should print the usage and raise FinishIteration. + """ + assert mock_llm is not None + test_text = r"\llm help" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + # Should return usage message when "help" subcommand or variant is provided + assert exc_info.value.results == [SQLResult(preamble=USAGE)] + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.run_external_cmd") +def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor): + string = "Hello, no SQL today." + # Suppose the LLM returns some text without fenced SQL + mock_run_cmd.return_value = (0, string) + test_text = r"\llm -c 'Something?'" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + # Expect raw output when no SQL fence found + assert exc_info.value.results == [SQLResult(preamble=string)] + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.run_external_cmd") +def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor): + # Return text containing a fenced SQL block + sql_text = "SELECT * FROM users;" + fenced = f"Here you go:\n```sql\n{sql_text}\n```" + mock_run_cmd.return_value = (0, fenced) + test_text = r"\llm -c 'Rewrite SQL'" + result, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0) + # Without verbosity, result is empty, sql extracted + assert sql == sql_text + assert result == "" + assert isinstance(duration, float) + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.run_external_cmd") +def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor): + # 'models' is a known subcommand + test_text = r"\llm models" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False) + assert exc_info.value.results is None + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.run_external_cmd") +def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor): + test_text = r"\llm --help" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False) + assert exc_info.value.results is None + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.run_external_cmd") +def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor): + test_text = r"\llm install openai" + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(test_text, executor, 'mysql', 0, 0) + mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True) + assert exc_info.value.results is None + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.ensure_mycli_template") +@patch("mycli.packages.special.llm.sql_using_llm") +def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + r""" + \llm prompt 'question' should use template and call sql_using_llm + """ + mock_sql_using_llm.return_value = ("CTX", "SELECT 1;") + test_text = r"\llm prompt 'Test?'" + context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0) + mock_ensure_template.assert_called_once() + mock_sql_using_llm.assert_called() + assert context == "CTX" + assert sql == "SELECT 1;" + assert isinstance(duration, float) + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.ensure_mycli_template") +@patch("mycli.packages.special.llm.sql_using_llm") +def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + r""" + \llm 'question' treats as prompt and returns SQL + """ + mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;") + test_text = r"\llm 'Top 10?'" + context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0) + mock_ensure_template.assert_called_once() + mock_sql_using_llm.assert_called() + assert context == "CTX2" + assert sql == "SELECT 2;" + assert isinstance(duration, float) + + +@patch("mycli.packages.special.llm.llm") +@patch("mycli.packages.special.llm.ensure_mycli_template") +@patch("mycli.packages.special.llm.sql_using_llm") +def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template, mock_llm, executor): + r""" + \llm+ returns verbose context and SQL + """ + mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;") + test_text = r"\llm- 'Succinct?'" + context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0) + assert context == "" + assert sql == "SELECT 42;" + assert isinstance(duration, float) + + +def test_handle_llm_without_dependencies(executor, monkeypatch) -> None: + monkeypatch.setattr(llm_module, "LLM_IMPORTED", False) + + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(r"\llm anything", executor, "mysql", 0, 0) + + assert exc_info.value.results == [SQLResult(preamble=NEED_DEPENDENCIES)] + + +@patch("mycli.packages.special.llm.llm") +def test_handle_llm_wraps_context_errors(mock_llm, executor, monkeypatch) -> None: + assert mock_llm is not None + monkeypatch.setattr(llm_module, "ensure_mycli_template", lambda: (_ for _ in ()).throw(ValueError("bad template"))) + + with pytest.raises(RuntimeError, match="bad template"): + handle_llm(r"\llm 'Top 10?'", executor, "mysql", 0, 0) + + +def test_is_llm_command(): + # Valid llm command variants + for cmd in ["\\llm", "\\ai"]: + assert is_llm_command(cmd + " 'x'") + # Invalid commands + assert not is_llm_command("select * from table;") + + +def test_sql_using_llm_no_connection(): + # Should error if no database cursor provided + with pytest.raises(RuntimeError) as exc_info: + sql_using_llm(None, question="test") + assert "Connect to a database" in str(exc_info.value) + + +def test_truncate_list_elements_and_table_lines(monkeypatch) -> None: + monkeypatch.setattr(llm_module.sys, "getsizeof", lambda value: len(value) if isinstance(value, (str, bytes)) else 8) + + row = ["a" * 250, b"b" * 250, 1] + truncated = truncate_list_elements(row, prompt_field_truncate=250, prompt_section_truncate=300) + assert truncated == ["a" * 50, b"b" * 50, 1] + assert truncate_list_elements(row, prompt_field_truncate=0, prompt_section_truncate=0) is row + assert truncate_list_elements(["abcdef"], prompt_field_truncate=3, prompt_section_truncate=0) == ["abc"] + + table = ["a" * 100, "b" * 100, "c" * 100] + assert truncate_table_lines(table.copy(), prompt_section_truncate=0) == table + assert truncate_table_lines(table.copy(), prompt_section_truncate=210) == ["a" * 100, "b" * 100] + assert truncate_table_lines(table.copy(), prompt_section_truncate=150) == ["a" * 100] + assert truncate_table_lines(["a" * 100], prompt_section_truncate=50) == [] + + +def test_get_schema_and_sample_data_use_cache_and_skip_bad_rows(monkeypatch) -> None: + llm_module.SCHEMA_DATA_CACHE.clear() + llm_module.SAMPLE_DATA_CACHE.clear() + monkeypatch.setattr(llm_module.click, "echo", lambda message: None) + monkeypatch.setattr(llm_module.sys, "getsizeof", lambda value: len(value) if isinstance(value, (str, bytes)) else 8) + + class DummyCursor: + def __init__(self) -> None: + self.executed: list[str] = [] + self.description: list[tuple[str, None]] = [] + self._rows: list[tuple[str]] = [] + self._row: tuple[int, str] | None = None + + def execute(self, query: str) -> None: + self.executed.append(query) + if "information_schema.columns" in query: + self._rows = [("orders(id int)",), ("users(name text)",)] + return + if query == "SHOW TABLES": + self._rows = [("orders",), ("broken",), ("empty",)] + return + if "`orders`" in query: + self.description = [("id", None), ("name", None)] + self._row = (1, "alice") + return + if "`broken`" in query: + raise RuntimeError("bad table") + if "`empty`" in query: + self.description = [("id", None)] + self._row = None + return + raise AssertionError(f"unexpected query: {query}") + + def fetchall(self) -> list[tuple[str]]: + return self._rows + + def fetchone(self) -> tuple[int, str] | None: + return self._row + + cursor = DummyCursor() + + assert get_schema(cast(Any, cursor), "mysql", 0) == "orders(id int)\nusers(name text)" + assert get_schema(cast(Any, cursor), "mysql", 0) == "orders(id int)\nusers(name text)" + sample_data = get_sample_data(cast(Any, cursor), "mysql", 10, 100) + assert sample_data == {"orders": [("id", 1), ("name", "alice")]} + assert get_sample_data(cast(Any, cursor), "mysql", 10, 100) == sample_data + assert cursor.executed.count("SHOW TABLES") == 1 + assert sum(1 for query in cursor.executed if "information_schema.columns" in query) == 1 + + +# Test sql_using_llm with dummy cursor and fenced SQL output +@patch("mycli.packages.special.llm.run_external_cmd") +def test_sql_using_llm_success(mock_run_cmd): + llm_module.SCHEMA_DATA_CACHE.clear() + llm_module.SAMPLE_DATA_CACHE.clear() + + # Dummy cursor simulating database schema and sample data + class DummyCursor: + def __init__(self): + self._last = [] + self.executed = [] + + def execute(self, query): + self.executed.append(query) + if "information_schema.columns" in query: + self._last = [("table1(col1 int,col2 text)",), ("table2(colA varchar(20))",)] + elif query.strip().upper().startswith("SHOW TABLES"): + self._last = [("table1",), ("table2",)] + elif query.strip().upper().startswith("SELECT * FROM"): + self.description = [("col1", None), ("col2", None)] + self._row = (1, "abc") + + def fetchall(self): + return getattr(self, "_last", []) + + def fetchone(self): + return getattr(self, "_row", None) + + dummy_cur = DummyCursor() + # Simulate llm CLI returning a fenced SQL result + sql_text = "SELECT 1, 'abc';" + fenced = f"Note\n```sql\n{sql_text}\n```" + mock_run_cmd.return_value = (0, fenced) + result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql') + + assert any("information_schema.columns" in query for query in dummy_cur.executed) + assert "SHOW TABLES" in dummy_cur.executed + assert any(query.strip().upper().startswith("SELECT * FROM") for query in dummy_cur.executed) + mock_run_cmd.assert_called_once_with( + "llm", + "--template", + llm_module.LLM_TEMPLATE_NAME, + "--param", + "db_schema", + "table1(col1 int,col2 text)\ntable2(colA varchar(20))", + "--param", + "sample_data", + {"table1": [("col1", 1), ("col2", "abc")], "table2": [("col1", 1), ("col2", "abc")]}, + "--param", + "question", + "dummy", + " ", + capture_output=True, + ) + assert result == fenced + assert sql == sql_text + + +def test_sql_using_llm_requires_schema_and_allows_missing_sql(monkeypatch) -> None: + class DummyCursor: + pass + + with pytest.raises(RuntimeError, match="Choose a schema and try again."): + sql_using_llm(cast(Any, DummyCursor()), question="test", dbname="") + + monkeypatch.setattr(llm_module, "get_schema", lambda cur, dbname, truncate: "schema") + monkeypatch.setattr(llm_module, "get_sample_data", lambda cur, dbname, field_truncate, section_truncate: {"t": [("c", 1)]}) + monkeypatch.setattr(llm_module.click, "echo", lambda message: None) + monkeypatch.setattr(llm_module, "run_external_cmd", lambda *args, **kwargs: (0, "No fenced SQL here.")) + + result, sql = sql_using_llm(cast(Any, DummyCursor()), question="test", dbname="mysql") + + assert result == "No fenced SQL here." + assert sql == "" + + +# Test handle_llm supports registered command names without args +@pytest.mark.parametrize("prefix", [r"\llm", r"\ai"]) +def test_handle_llm_registered_aliases_without_args(prefix, executor, monkeypatch): + assert prefix in COMMANDS + assert COMMANDS[prefix].handler is COMMANDS[r"\llm"].handler + assert COMMANDS[prefix].command == r"\llm" + monkeypatch.setattr(llm_module, "llm", object()) + with pytest.raises(llm_module.FinishIteration) as exc_info: + handle_llm(prefix, executor, 'mysql', 0, 0) + assert exc_info.value.results == [SQLResult(preamble=USAGE)] diff --git a/test/pytests/test_special_main.py b/test/pytests/test_special_main.py new file mode 100644 index 00000000..3c1b2e77 --- /dev/null +++ b/test/pytests/test_special_main.py @@ -0,0 +1,455 @@ +import builtins +from collections.abc import Iterator +import importlib +import importlib.util +import sys +from types import ModuleType +from typing import Any, cast + +import pytest + +from mycli.constants import DOCS_URL, ISSUES_URL +from mycli.packages.special import main as special_main +from mycli.packages.sqlresult import SQLResult + + +@pytest.fixture +def restore_commands() -> Iterator[None]: + original_commands = special_main.COMMANDS.copy() + original_case_sensitive_commands = special_main.CASE_SENSITIVE_COMMANDS.copy() + original_case_insensitive_commands = special_main.CASE_INSENSITIVE_COMMANDS.copy() + try: + yield + finally: + special_main.COMMANDS.clear() + special_main.COMMANDS.update(original_commands) + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.update(original_case_sensitive_commands) + special_main.CASE_INSENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.update(original_case_insensitive_commands) + + +class FakeHelpCursor: + def __init__(self, responses: list[dict[str, Any]]) -> None: + self._responses = responses + self.calls: list[tuple[str, object]] = [] + self.description: list[tuple[str, object | None]] | None = None + self.rowcount = 0 + + def execute(self, query: str, params: object) -> None: + self.calls.append((query, params)) + response = self._responses.pop(0) + self.description = response['description'] + self.rowcount = response['rowcount'] + + +def load_isolated_special_main(module_name: str) -> ModuleType: + assert special_main.__file__ is not None + spec = importlib.util.spec_from_file_location(module_name, special_main.__file__) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + except Exception: + sys.modules.pop(module_name, None) + raise + return module + + +@pytest.mark.parametrize( + ('sql', 'expected'), + [ + ('help select', ('help', special_main.CommandVerbosity.NORMAL, 'select')), + (r'\llm+ prompt', (r'\llm', special_main.CommandVerbosity.VERBOSE, 'prompt')), + (r'\llm- prompt', (r'\llm', special_main.CommandVerbosity.SUCCINCT, 'prompt')), + ('help spaced ', ('help', special_main.CommandVerbosity.NORMAL, 'spaced')), + ], +) +def test_parse_special_command(sql: str, expected: tuple[str, special_main.CommandVerbosity, str]) -> None: + assert special_main.parse_special_command(sql) == expected + + +def test_register_special_command_adds_primary_and_alias_entries(restore_commands: None) -> None: + def handler() -> None: + return None + + special_main.COMMANDS.clear() + special_main.register_special_command( + handler, + 'Demo', + 'demo', + 'Description', + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], + ) + + assert special_main.COMMANDS['demo'] == special_main.SpecialCommand( + handler, + 'Demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.PARSED_QUERY, + hidden=False, + case_sensitive=False, + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], + ) + assert special_main.COMMANDS['\\d'] == special_main.SpecialCommand( + handler, + 'Demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.PARSED_QUERY, + hidden=True, + case_sensitive=False, + aliases=None, + ) + + +def test_register_special_command_tracks_case_insensitive_commands(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.clear() + + special_main.register_special_command( + lambda: None, + 'Demo', + 'demo', + 'Description', + aliases=[special_main.SpecialCommandAlias('\\d', case_sensitive=False)], + ) + + assert special_main.CASE_SENSITIVE_COMMANDS == set() + assert special_main.CASE_INSENSITIVE_COMMANDS == {'demo', '\\d'} + + +def test_special_command_decorator_registers_case_sensitive_command(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.clear() + + @special_main.special_command('Camel', 'Camel', 'Description', case_sensitive=True) + def handler() -> None: + return None + + assert special_main.COMMANDS['Camel'].handler is handler + assert 'Camel' in special_main.CASE_SENSITIVE_COMMANDS + assert special_main.CASE_INSENSITIVE_COMMANDS == set() + assert 'camel' not in special_main.COMMANDS + + +def test_execute_raises_when_command_is_missing() -> None: + with pytest.raises(special_main.CommandNotFound, match='Command not found: missing'): + special_main.execute(cast(Any, None), 'missing') + + +def test_execute_raises_for_case_sensitive_mismatch(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.register_special_command(lambda: None, 'Camel', 'Camel', 'Description', case_sensitive=True) + + with pytest.raises(special_main.CommandNotFound, match='Command not found: camel'): + special_main.execute(cast(Any, None), 'camel') + + +def test_execute_raises_for_case_sensitive_alias_lookup(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.register_special_command( + lambda: None, + 'Demo', + 'Demo', + 'Description', + case_sensitive=True, + aliases=[special_main.SpecialCommandAlias('demo', case_sensitive=True)], + ) + + with pytest.raises(special_main.CommandNotFound, match='Command not found: DEMO'): + special_main.execute(cast(Any, None), 'DEMO') + + +def test_execute_raises_when_case_sensitive_exact_lookup_falls_back_to_lowercase(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.clear() + special_main.COMMANDS['camel'] = special_main.SpecialCommand( + lambda: None, + 'Camel', + 'Camel', + 'Description', + arg_type=special_main.ArgType.NO_QUERY, + hidden=False, + case_sensitive=True, + aliases=None, + ) + special_main.CASE_SENSITIVE_COMMANDS.add('Camel') + + with pytest.raises(special_main.CommandNotFound, match='Command not found: Camel'): + special_main.execute(cast(Any, None), 'Camel') + + +def test_execute_dispatches_no_query_command(restore_commands: None) -> None: + calls: list[str] = [] + + def handler() -> list[SQLResult]: + calls.append('called') + return [SQLResult(status='ok')] + + special_main.COMMANDS.clear() + special_main.register_special_command( + handler, + 'demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.NO_QUERY, + ) + + assert special_main.execute(cast(Any, None), 'demo') == [SQLResult(status='ok')] + assert calls == ['called'] + + +def test_execute_uses_lowercase_lookup_for_case_insensitive_command(restore_commands: None) -> None: + calls: list[str] = [] + + def handler() -> list[SQLResult]: + calls.append('called') + return [SQLResult(status='ok')] + + special_main.COMMANDS.clear() + special_main.register_special_command( + handler, + 'demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.NO_QUERY, + ) + + assert special_main.execute(cast(Any, None), 'DEMO') == [SQLResult(status='ok')] + assert calls == ['called'] + + +def test_execute_dispatches_parsed_query_command(restore_commands: None) -> None: + calls: list[tuple[object, str, bool]] = [] + + def handler(*, cur: object, arg: str, command_verbosity: bool) -> list[SQLResult]: + calls.append((cur, arg, command_verbosity)) + return [SQLResult(status='parsed')] + + special_main.COMMANDS.clear() + special_main.register_special_command( + handler, + 'demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.PARSED_QUERY, + ) + + cur = object() + assert special_main.execute(cast(Any, cur), 'demo+ value') == [SQLResult(status='parsed')] + assert calls == [(cur, 'value', True)] + + +def test_execute_dispatches_raw_query_command(restore_commands: None) -> None: + calls: list[tuple[object, str]] = [] + + def handler(*, cur: object, query: str) -> list[SQLResult]: + calls.append((cur, query)) + return [SQLResult(status='raw')] + + special_main.COMMANDS.clear() + special_main.register_special_command( + handler, + 'demo', + 'demo', + 'Description', + arg_type=special_main.ArgType.RAW_QUERY, + case_sensitive=True, + ) + + cur = object() + assert special_main.execute(cast(Any, cur), 'demo payload') == [SQLResult(status='raw')] + assert calls == [(cur, 'demo payload')] + + +def test_execute_routes_help_with_argument_to_keyword_help(monkeypatch) -> None: + calls: list[tuple[object, str]] = [] + + def fake_show_keyword_help(cur: object, arg: str) -> list[SQLResult]: + calls.append((cur, arg)) + return [SQLResult(status='keyword')] + + monkeypatch.setattr(special_main, 'show_keyword_help', fake_show_keyword_help) + + cur = object() + assert special_main.execute(cast(Any, cur), 'help select') == [SQLResult(status='keyword')] + assert calls == [(cur, 'select')] + + +def test_execute_routes_uppercase_help_with_argument_to_keyword_help(monkeypatch) -> None: + calls: list[tuple[object, str]] = [] + + def fake_show_keyword_help(cur: object, arg: str) -> list[SQLResult]: + calls.append((cur, arg)) + return [SQLResult(status='keyword')] + + monkeypatch.setattr(special_main, 'show_keyword_help', fake_show_keyword_help) + + cur = object() + assert special_main.execute(cast(Any, cur), 'HELP select') == [SQLResult(status='keyword')] + assert calls == [(cur, 'select')] + + +def test_execute_raises_for_unknown_arg_type(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.clear() + special_main.COMMANDS['demo'] = special_main.SpecialCommand( + lambda: None, + 'demo', + 'demo', + 'Description', + arg_type=cast(Any, object()), + hidden=False, + case_sensitive=False, + aliases=None, + ) + special_main.CASE_INSENSITIVE_COMMANDS.add('demo') + + with pytest.raises(special_main.CommandNotFound, match='Command type not found: demo'): + special_main.execute(cast(Any, None), 'demo') + + +def test_show_help_lists_only_visible_commands(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.register_special_command( + lambda: None, + 'visible', + 'visible', + 'Visible command', + aliases=[special_main.SpecialCommandAlias('\\v', case_sensitive=False)], + ) + special_main.register_special_command(lambda: None, 'hidden', 'hidden', 'Hidden command', hidden=True) + + result = special_main.show_help()[0] + + assert result.header == ['Command', 'Shortcut', 'Usage', 'Description'] + assert result.rows == [('visible', '\\v', 'visible', 'Visible command')] + assert result.postamble == f'Docs index — {DOCS_URL}' + + +def test_show_keyword_help_for_special_command(restore_commands: None) -> None: + special_main.COMMANDS.clear() + special_main.CASE_SENSITIVE_COMMANDS.clear() + special_main.CASE_INSENSITIVE_COMMANDS.clear() + special_main.register_special_command(lambda: None, 'demo', 'demo ', 'Demo command') + + result = special_main.show_keyword_help(cast(Any, None), 'demo+')[0] + + assert result.header == ['name', 'description', 'example'] + assert result.rows == [('demo', 'demo \nDemo command', '')] + + +def test_show_keyword_help_for_case_sensitive_special_alias() -> None: + result = special_main.show_keyword_help(cast(Any, None), r'\e')[0] + + assert result.header == ['name', 'description', 'example'] + assert result.rows == [ + ( + r'\e', + '\\edit | \\edit \nEdit query with editor (uses $VISUAL or $EDITOR).', + '', + ) + ] + + +def test_show_keyword_help_exact_match() -> None: + cur = FakeHelpCursor([ + {'description': [('name', None)], 'rowcount': 1}, + ]) + + result = special_main.show_keyword_help(cast(Any, cur), '"select"')[0] + + assert cur.calls == [('help %s', 'select')] + assert result.header == ['name'] + assert cast(Any, result.rows) is cur + + +def test_show_keyword_help_similar_match() -> None: + cur = FakeHelpCursor([ + {'description': None, 'rowcount': 0}, + {'description': [('name', None)], 'rowcount': 2}, + ]) + + result = special_main.show_keyword_help(cast(Any, cur), "'select'")[0] + + assert cur.calls == [('help %s', 'select'), ('help %s', ('%select%',))] + assert result.preamble == 'Similar terms:' + assert result.header == ['name'] + assert cast(Any, result.rows) is cur + + +def test_show_keyword_help_no_match() -> None: + cur = FakeHelpCursor([ + {'description': None, 'rowcount': 0}, + {'description': None, 'rowcount': 0}, + ]) + + result = special_main.show_keyword_help(cast(Any, cur), 'missing')[0] + + assert result.status == 'No help found for "missing".' + + +def test_file_bug_opens_browser(monkeypatch) -> None: + calls: list[str] = [] + monkeypatch.setattr(special_main.webbrowser, 'open_new_tab', lambda url: calls.append(url)) + + result = special_main.file_bug()[0] + + assert calls == [ISSUES_URL] + assert result.status == f'{ISSUES_URL} — press "New Issue"' + + +def test_quit_command_raises_eoferror() -> None: + with pytest.raises(EOFError): + special_main.quit_() + + +def test_stub_command_raises_not_implemented() -> None: + with pytest.raises(NotImplementedError): + special_main.stub() + + +def test_llm_stub_raises_not_implemented_when_present() -> None: + if hasattr(special_main, 'llm_stub'): + with pytest.raises(NotImplementedError): + special_main.llm_stub() + + +def test_reload_special_main_without_llm_support(monkeypatch) -> None: + with monkeypatch.context() as m: + m.setenv('MYCLI_LLM_OFF', '1') + isolated_main = load_isolated_special_main('test_special_main_without_llm') + try: + assert isolated_main.LLM_IMPORTED is False + assert r'\llm' not in isolated_main.COMMANDS + assert r'\ai' not in isolated_main.COMMANDS + finally: + sys.modules.pop('test_special_main_without_llm', None) + + +def test_reload_special_main_handles_llm_import_error(monkeypatch) -> None: + original_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): # noqa: A002 + if name == 'llm': + raise ImportError('no llm') + return original_import(name, globals, locals, fromlist, level) + + with monkeypatch.context() as m: + m.delenv('MYCLI_LLM_OFF', raising=False) + m.setattr(builtins, '__import__', fake_import) + isolated_main = load_isolated_special_main('test_special_main_import_error') + try: + assert isolated_main.LLM_IMPORTED is False + assert r'\llm' not in isolated_main.COMMANDS + assert r'\ai' not in isolated_main.COMMANDS + finally: + sys.modules.pop('test_special_main_import_error', None) diff --git a/test/pytests/test_special_utils.py b/test/pytests/test_special_utils.py new file mode 100644 index 00000000..efea02df --- /dev/null +++ b/test/pytests/test_special_utils.py @@ -0,0 +1,279 @@ +# type: ignore + +import os +import pathlib +import tempfile +from unittest.mock import MagicMock + +import pymysql +import pytest + +import mycli.packages.special.utils +from mycli.packages.special.utils import ( + CACHED_SSL_VERSION, + format_uptime, + get_local_timezone, + get_server_timezone, + get_ssl_cipher, + get_ssl_version, + get_uptime, + get_warning_count, + handle_cd_command, +) +from test.utils import TEMPFILE_PREFIX + + +@pytest.fixture(autouse=True) +def clear_ssl_cache() -> None: + CACHED_SSL_VERSION.clear() + + +def test_handle_cd_command_rejects_non_cd_command() -> None: + handled, message = handle_cd_command(['pwd']) + + assert handled is False + assert message == 'Not a cd command.' + + +def test_handle_cd_command_requires_exactly_one_directory() -> None: + handled, message = handle_cd_command(['cd']) + + assert handled is False + assert message == 'Exactly one directory name must be provided.' + + +def test_handle_cd_command_changes_directory_and_echoes_cwd(monkeypatch) -> None: + echoed = [] + + monkeypatch.setattr(mycli.packages.special.utils.click, 'echo', lambda message, err=False: echoed.append((message, err))) + monkeypatch.chdir(os.getcwd()) + + # resolve() is needed for mac /private/var arrangement + with tempfile.TemporaryDirectory(prefix=TEMPFILE_PREFIX) as tempdir: + tempdir_resolved = str(pathlib.Path(tempdir).resolve()) + handled, message = handle_cd_command(['cd', tempdir_resolved]) + assert str(pathlib.Path(os.getcwd()).resolve()) == tempdir_resolved + assert handled is True + assert message is None + assert echoed == [(tempdir_resolved, True)] + + +def test_handle_cd_command_returns_oserror_message(monkeypatch) -> None: + def raise_oserror(directory: str) -> None: + raise OSError(2, 'No such file or directory') + + monkeypatch.setattr(mycli.packages.special.utils.os, 'chdir', raise_oserror) + + handled, message = handle_cd_command(['cd', '/missing']) + + assert handled is False + assert message == 'No such file or directory' + + +def test_format_uptime(): + seconds = 59 + assert '59 sec' == format_uptime(seconds) + + seconds = 120 + assert '2 min 0 sec' == format_uptime(seconds) + + seconds = 54890 + assert '15 hours 14 min 50 sec' == format_uptime(seconds) + + seconds = 598244 + assert '6 days 22 hours 10 min 44 sec' == format_uptime(seconds) + + seconds = 522600 + assert '6 days 1 hour 10 min 0 sec' == format_uptime(seconds) + + +def test_format_uptime_uses_singular_units() -> None: + assert format_uptime('90061') == '1 day 1 hour 1 min 1 sec' + + +def test_get_uptime_returns_value_from_status_row() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Uptime', '15') + + uptime = get_uptime(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Uptime"') + assert uptime == 15 + + +def test_get_uptime_defaults_to_zero_for_missing_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Uptime', None) + + assert get_uptime(cur) == 0 + + +def test_get_uptime_ignores_operational_error() -> None: + cur = MagicMock() + cur.execute.side_effect = pymysql.err.OperationalError() + + assert get_uptime(cur) == 0 + + +def test_get_warning_count_returns_value_from_count_row() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('7',) + + warning_count = get_warning_count(cur) + + cur.execute.assert_called_once_with('SHOW COUNT(*) WARNINGS') + assert warning_count == 7 + + +def test_get_warning_count_defaults_to_zero_for_missing_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = (None,) + + assert get_warning_count(cur) == 0 + + +def test_get_warning_count_ignores_operational_error() -> None: + cur = MagicMock() + cur.execute.side_effect = pymysql.err.OperationalError() + + assert get_warning_count(cur) == 0 + + +def test_get_ssl_version_fetches_and_caches_value() -> None: + cur = MagicMock() + cur.connection = MagicMock() + cur.connection.thread_id.return_value = 42 + cur.fetchone.return_value = ('Ssl_version', 'TLSv1.3') + + first = get_ssl_version(cur) + second = get_ssl_version(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_version"') + assert first == 'TLSv1.3' + assert second == 'TLSv1.3' + + +def test_get_ssl_version_caches_missing_row_as_none() -> None: + cur = MagicMock() + cur.connection = MagicMock() + cur.connection.thread_id.return_value = 42 + cur.fetchone.return_value = None + + first = get_ssl_version(cur) + second = get_ssl_version(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_version"') + assert first is None + assert second is None + + +def test_get_ssl_version_returns_none_for_empty_value_and_caches_it() -> None: + cur = MagicMock() + cur.connection = MagicMock() + cur.connection.thread_id.return_value = 42 + cur.fetchone.return_value = ('Ssl_version', '') + + first = get_ssl_version(cur) + second = get_ssl_version(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_version"') + assert first is None + assert second is None + + +def test_get_ssl_version_ignores_operational_error() -> None: + cur = MagicMock() + cur.connection = MagicMock() + cur.connection.thread_id.return_value = 42 + cur.execute.side_effect = pymysql.err.OperationalError() + + assert get_ssl_version(cur) is None + + +def test_get_ssl_cipher_returns_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Ssl_cipher', 'TLS_AES_256_GCM_SHA384') + + ssl_cipher = get_ssl_cipher(cur) + + cur.execute.assert_called_once_with('SHOW STATUS LIKE "Ssl_cipher"') + assert ssl_cipher == 'TLS_AES_256_GCM_SHA384' + + +def test_get_ssl_cipher_returns_none_for_missing_row() -> None: + cur = MagicMock() + cur.fetchone.return_value = None + + assert get_ssl_cipher(cur) is None + + +def test_get_ssl_cipher_returns_none_for_empty_value() -> None: + cur = MagicMock() + cur.fetchone.return_value = ('Ssl_cipher', '') + + assert get_ssl_cipher(cur) is None + + +def test_get_ssl_cipher_ignores_operational_error() -> None: + cur = MagicMock() + cur.execute.side_effect = pymysql.err.OperationalError() + + assert get_ssl_cipher(cur) is None + + +def test_get_server_timezone_prefers_system_timezone_when_requested() -> None: + variables = { + 'time_zone': 'SYSTEM', + 'system_time_zone': 'UTC', + } + + assert get_server_timezone(variables) == 'UTC' + + +def test_get_server_timezone_returns_explicit_timezone() -> None: + variables = { + 'time_zone': '+02:00', + 'system_time_zone': 'UTC', + } + + assert get_server_timezone(variables) == '+02:00' + + +def test_get_server_timezone_returns_empty_string_when_keys_are_missing() -> None: + assert get_server_timezone({}) == '' + + +def test_get_local_timezone_returns_tzname(monkeypatch) -> None: + class FakeAwareDatetime: + def tzname(self) -> str: + return 'EDT' + + class FakeDatetime: + @staticmethod + def now() -> 'FakeDatetime': + return FakeDatetime() + + def astimezone(self) -> FakeAwareDatetime: + return FakeAwareDatetime() + + monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime) + + assert get_local_timezone() == 'EDT' + + +def test_get_local_timezone_returns_empty_string_when_tzname_is_none(monkeypatch) -> None: + class FakeAwareDatetime: + def tzname(self) -> None: + return None + + class FakeDatetime: + @staticmethod + def now() -> 'FakeDatetime': + return FakeDatetime() + + def astimezone(self) -> FakeAwareDatetime: + return FakeAwareDatetime() + + monkeypatch.setattr(mycli.packages.special.utils.datetime, 'datetime', FakeDatetime) + + assert get_local_timezone() == '' diff --git a/test/pytests/test_sql_utils.py b/test/pytests/test_sql_utils.py new file mode 100644 index 00000000..1be26ef1 --- /dev/null +++ b/test/pytests/test_sql_utils.py @@ -0,0 +1,686 @@ +# type: ignore + +from types import SimpleNamespace + +import pytest +import sqlparse +from sqlparse.sql import Identifier, IdentifierList, Token, TokenList +from sqlparse.tokens import DML, Keyword, Punctuation + +from mycli.packages import sql_utils +from mycli.packages.sql_utils import ( + extract_columns_from_select, + extract_from_part, + extract_table_identifiers, + extract_tables, + extract_tables_from_complete_statements, + find_prev_keyword, + get_last_select, + is_destructive, + is_dropping_database, + is_mutating, + is_select, + is_subselect, + last_word, + need_completion_refresh, + need_completion_reset, + queries_start_with, + query_has_where_clause, + query_is_single_table_update, + query_starts_with, +) + + +def test_extract_columns_from_select(): + columns = extract_columns_from_select('SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE, COLUMN_DEFAULT FROM INFORMATION_SCHEMA.COLUMNS') + assert columns == ['COLUMN_NAME', 'DATA_TYPE', 'IS_NULLABLE', 'COLUMN_DEFAULT'] + + +def test_extract_columns_from_select_empty(): + columns = extract_columns_from_select('') + assert columns == [] + + +def test_extract_columns_from_select_update(): + columns = extract_columns_from_select('UPDATE table SET value = 1 WHERE id = 1') + assert columns == [] + + +def test_empty_string(): + tables = extract_tables('') + assert tables == [] + + +def test_simple_select_single_table(): + tables = extract_tables('select * from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_single_table_schema_qualified(): + tables = extract_tables('select * from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_multiple_tables(): + tables = extract_tables('select * from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_multiple_tables_schema_qualified(): + tables = extract_tables('select * from abc.def, ghi.jkl') + assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] + + +def test_simple_select_with_cols_single_table(): + tables = extract_tables('select a,b from abc') + assert tables == [(None, 'abc', None)] + + +def test_simple_select_with_cols_single_table_schema_qualified(): + tables = extract_tables('select a,b from abc.def') + assert tables == [('abc', 'def', None)] + + +def test_simple_select_with_cols_multiple_tables(): + tables = extract_tables('select a,b from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_simple_select_with_cols_multiple_tables_with_schema(): + tables = extract_tables('select a,b from abc.def, def.ghi') + assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] + + +def test_select_with_hanging_comma_single_table(): + tables = extract_tables('select a, from abc') + assert tables == [(None, 'abc', None)] + + +def test_select_with_hanging_comma_multiple_tables(): + tables = extract_tables('select a, from abc, def') + assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] + + +def test_select_with_hanging_period_multiple_tables(): + tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') + assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] + + +def test_simple_insert_single_table(): + tables = extract_tables('insert into abc (id, name) values (1, "def")') + + # sqlparse mistakenly assigns an alias to the table + # assert tables == [(None, 'abc', None)] + assert tables == [(None, 'abc', 'abc')] + + +def test_simple_insert_single_table_schema_qualified(): + tables = extract_tables('insert into abc.def (id, name) values (1, "def")') + assert tables == [('abc', 'def', None)] + + +def test_simple_update_table(): + tables = extract_tables('update abc set id = 1') + assert tables == [(None, 'abc', None)] + + +def test_simple_update_table_with_schema(): + tables = extract_tables('update abc.def set id = 1') + assert tables == [('abc', 'def', None)] + + +def test_join_table(): + tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') + assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] + + +def test_join_table_schema_qualified(): + tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') + assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] + + +def test_join_as_table(): + tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + + +def test_extract_tables_from_complete_statements(): + tables = extract_tables_from_complete_statements('SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + + +def test_extract_tables_from_complete_statements_cte(): + tables = extract_tables_from_complete_statements('WITH my_cte (id, num) AS ( SELECT id, COUNT(1) FROM my_table GROUP BY id ) SELECT *') + assert tables == [(None, 'my_table', None)] + + +# this would confuse plain extract_tables() per #1122 +def test_extract_tables_from_multiple_complete_statements(): + tables = extract_tables_from_complete_statements(r'\T sql-insert; SELECT * FROM my_table AS m WHERE m.a > 5') + assert tables == [(None, 'my_table', 'm')] + + +def test_query_starts_with(): + query = 'USE test;' + assert query_starts_with(query, ('use',)) is True + + query = 'DROP DATABASE test;' + assert query_starts_with(query, ('use',)) is False + + +def test_query_starts_with_comment(): + query = '# comment\nUSE test;' + assert query_starts_with(query, ('use',)) is True + + +def test_queries_start_with(): + sql = '# comment\nshow databases;use foo;' + assert queries_start_with(sql, ['show', 'select']) is True + assert queries_start_with(sql, ['use', 'drop']) is True + assert queries_start_with(sql, ['delete', 'update']) is False + + +@pytest.mark.parametrize( + ('text', 'include', 'expected'), + [ + ('abc', 'alphanum_underscore', 'abc'), + (' abc', 'alphanum_underscore', 'abc'), + ('', 'alphanum_underscore', ''), + (' ', 'alphanum_underscore', ''), + ('abc ', 'alphanum_underscore', ''), + ('abc def', 'alphanum_underscore', 'def'), + ('abc def ', 'alphanum_underscore', ''), + ('abc def;', 'alphanum_underscore', ''), + ('bac $def', 'alphanum_underscore', 'def'), + ('bac $def', 'most_punctuations', '$def'), + (r'bac \def', 'most_punctuations', r'\def'), + (r'bac \def;', 'most_punctuations', r'\def;'), + ('bac::def', 'most_punctuations', 'def'), + ('abc:def', 'many_punctuations', 'def'), + ('abc.def', 'all_punctuations', 'abc.def'), + ], +) +def test_last_word(text, include, expected): + assert last_word(text, include=include) == expected + + +def test_is_subselect_returns_false_for_non_group_token(): + token = sqlparse.parse('foo')[0].tokens[0] + assert is_subselect(token) is False + + +def test_is_subselect_returns_false_for_group_without_dml(): + token = sqlparse.parse('(foo)')[0].tokens[0] + assert is_subselect(token) is False + + +def test_is_subselect_returns_true_for_group_with_select(): + token = sqlparse.parse('(select 1)')[0].tokens[0] + assert is_subselect(token) is True + + +def test_get_last_select_returns_empty_token_list_without_select(): + parsed = sqlparse.parse('update t set x = 1')[0] + assert list(get_last_select(parsed).flatten()) == [] + + +def test_get_last_select_returns_single_select_statement(): + parsed = sqlparse.parse('select c1')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c1' + + +def test_get_last_select_returns_single_select_statement_with_from(): + parsed = sqlparse.parse('select c1 from')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c1 from' + + +def test_get_last_select_returns_last_top_level_select(): + parsed = sqlparse.parse('select c1 union select c2')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c2' + + +def test_get_last_select_keeps_outer_select_for_nested_subselect(): + parsed = sqlparse.parse('select c1 from (select c2')[0] + tokens = get_last_select(parsed) + assert ''.join(token.value for token in tokens.flatten()) == 'select c2' + + +def token_values(tokens): + return [token.value for token in tokens if not getattr(token, 'is_whitespace', False)] + + +# todo: coverage of stop_at_punctuation parameter +def test_extract_from_part_returns_identifier_after_from(): + parsed = sqlparse.parse('select * from abc')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc'] + + +def test_extract_from_part_returns_identifier_list(): + parsed = sqlparse.parse('select * from abc, def')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc, def'] + + +def test_extract_from_part_handles_multiple_joins_and_skips_on_clause(): + parsed = sqlparse.parse('select * from abc join def on abc.id = def.id join ghi')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['abc', 'join', 'def', 'ghi'] + + +def test_extract_from_part_recurses_into_subselect_and_stops_at_punctuation(): + parsed = sqlparse.parse('select * from (select * from inner_table), outer_table')[0] + tokens = extract_from_part(parsed) + assert token_values(tokens) == ['inner_table'] + + +def test_extract_from_part_stops_at_punctuation_when_requested(): + parsed = TokenList([Token(Keyword, 'FROM'), Token(Punctuation, ','), Token(Keyword, 'SELECT')]) + tokens = extract_from_part(parsed, stop_at_punctuation=True) + assert token_values(tokens) == [] + + +def test_extract_table_identifiers_handles_identifier_list(): + parsed = sqlparse.parse('select * from abc a, def d')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [ + (None, 'abc', 'a'), + (None, 'def', 'd'), + ] + + +def test_extract_table_identifiers_handles_schema_qualified_identifier(): + parsed = sqlparse.parse('select * from abc.def x')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [('abc', 'def', 'x')] + + +def test_extract_table_identifiers_handles_function_tokens(): + parsed = sqlparse.parse('select * from my_func()')[0] + token_stream = extract_from_part(parsed) + assert list(extract_table_identifiers(token_stream)) == [(None, 'my_func', 'my_func')] + + +def test_extract_table_identifiers_skips_identifier_list_entries_without_identifier_methods(): + class BrokenIdentifierList(IdentifierList): + def get_identifiers(self): + return [object()] + + assert list(extract_table_identifiers(iter([BrokenIdentifierList([])]))) == [] + + +def test_extract_table_identifiers_uses_name_when_identifier_has_no_real_name(): + class NamelessIdentifier(Identifier): + def get_real_name(self): + return None + + def get_parent_name(self): + return None + + def get_name(self): + return 'fallback_name' + + def get_alias(self): + return None + + assert list(extract_table_identifiers(iter([NamelessIdentifier([])]))) == [ + (None, 'fallback_name', 'fallback_name'), + ] + + +@pytest.mark.parametrize( + ('sql', 'expected_keyword', 'expected_text'), + [ + ('', None, ''), + ('foo', None, ''), + ('select * from foo where bar = 1', 'where', 'select * from foo where'), + ('select * from foo where a = 1 and b = 2', 'where', 'select * from foo where'), + ('select * from foo where a between 1 and 2', 'where', 'select * from foo where'), + ('select count(', '(', 'select count('), + ], +) +def test_find_prev_keyword(sql, expected_keyword, expected_text): + token, text = find_prev_keyword(sql) + assert (token.value if token else None) == expected_keyword + assert text == expected_text + + +@pytest.mark.parametrize( + ('sql', 'is_single_table'), + [ + ('update test set x = 1', True), + ('update test t set x = 1', True), + ('update /* inline comment */ test set x = 1', True), + ('select 1', False), + ('', False), + ('update', False), + ('update test, foo set x = 1', False), + ('update test join foo on test.id = foo.id set test.x = 1', False), + ], +) +def test_query_is_single_table_update(sql, is_single_table): + assert query_is_single_table_update(sql) is is_single_table + + +def test_extract_columns_from_select_handles_falsey_last_select(monkeypatch): + monkeypatch.setattr(sql_utils, 'get_last_select', lambda _parsed: []) + assert extract_columns_from_select('select 1') == [] + + +def test_extract_columns_from_select_handles_single_identifier(monkeypatch): + class SingleIdentifier(Identifier): + def get_real_name(self): + return 'column_name' + + monkeypatch.setattr( + sql_utils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), SingleIdentifier([])]), + ) + + assert extract_columns_from_select('select column_name') == ['column_name'] + + +def test_extract_columns_from_select_ignores_unhandled_identifier_list_entries(monkeypatch): + class WeirdIdentifierList(IdentifierList): + def get_identifiers(self): + return [object()] + + monkeypatch.setattr( + sql_utils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), WeirdIdentifierList([])]), + ) + + assert extract_columns_from_select('select 1') == [] + + +def test_extract_columns_from_select_stops_at_keyword_before_collecting_columns(monkeypatch): + monkeypatch.setattr( + sql_utils, + 'get_last_select', + lambda _parsed: TokenList([Token(DML, 'SELECT'), Token(Keyword, 'FROM')]), + ) + + assert extract_columns_from_select('select 1') == [] + + +def test_extract_tables_from_complete_statements_returns_empty_for_falsey_rough_parse(monkeypatch): + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: []) + + assert extract_tables_from_complete_statements('select * from t') == [] + + +def test_extract_tables_from_complete_statements_skips_cte_table_identifiers(monkeypatch): + class FakeParentSelect: + def sql(self): + return 'WITH cte AS (SELECT 1) SELECT * FROM cte' + + class FakeIdentifier: + parent_select = FakeParentSelect() + db = '' + name = 'cte' + alias = '' + + class FakeStatement: + def find_all(self, _table_type): + return [FakeIdentifier()] + + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: ['stmt']) + monkeypatch.setattr(sql_utils.sqlglot, 'parse_one', lambda *_args, **_kwargs: FakeStatement()) + + assert extract_tables_from_complete_statements('with cte as (select 1) select * from cte') == [] + + +def test_query_is_single_table_update_returns_false_when_parse_result_is_empty(monkeypatch): + monkeypatch.setattr(sql_utils.sqlparse, 'parse', lambda _sql: []) + + assert query_is_single_table_update('update test set x = 1') is False + + +def test_is_destructive(): + sql = "use test;\nshow databases;\ndrop database foo;" + assert is_destructive(["drop"], sql) is True + + +def test_is_destructive_update_with_where_clause(): + sql = "use test;\nshow databases;\nUPDATE test SET x = 1 WHERE id = 1;" + assert is_destructive(["update"], sql) is False + + +def test_is_destructive_update_with_where_clause_and_comment(): + sql = "use test;\nshow databases;\nUPDATE /* inline comment */ test SET x = 1 WHERE id = 1;" + assert is_destructive(["update"], sql) is False + + +def test_is_destructive_update_multiple_tables_with_where_clause(): + sql = "use test;\nshow databases;\nUPDATE test, foo SET x = 1 WHERE id = 1;" + assert is_destructive(["update"], sql) is True + + +def test_is_destructive_update_without_where_clause(): + sql = "use test;\nshow databases;\nUPDATE test SET x = 1;" + assert is_destructive(["update"], sql) is True + + +def test_is_destructive_skips_empty_split_queries(monkeypatch): + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: ['', '']) + + assert is_destructive(['drop'], 'ignored') is False + + +def test_is_destructive_returns_false_when_no_query_matches_keywords() -> None: + assert is_destructive(['drop'], 'select 1; show databases;') is False + + +@pytest.mark.parametrize( + ("sql", "has_where_clause"), + [ + ("update test set dummy = 1;", False), + ("update test set dummy = 1 where id = 1);", True), + ], +) +def test_query_has_where_clause(sql, has_where_clause): + assert query_has_where_clause(sql) is has_where_clause + + +@pytest.mark.parametrize( + ("sql", "dbname", "is_dropping"), + [ + ("select bar from foo", "foo", False), + ('drop database "foo";', "`foo`", True), + ("drop schema foo", "foo", True), + ("drop schema foo", "bar", False), + ("drop database bar", "foo", False), + ("drop database foo", None, False), + ("drop database foo; create database foo", "foo", False), + ("drop database foo; create database bar", "foo", True), + ("select bar from foo; drop database bazz", "foo", False), + ("select bar from foo; drop database bazz", "bazz", True), + ("-- dropping database \n drop -- really dropping \n schema abc -- now it is dropped", "abc", True), + ], +) +def test_is_dropping_database(sql, dbname, is_dropping): + assert is_dropping_database(sql, dbname) == is_dropping + + +def test_is_dropping_database_skips_statements_without_enough_keywords(): + assert is_dropping_database('drop foo', 'foo') is False + + +@pytest.mark.parametrize( + ('queries', 'expected'), + [ + ('select 1;', False), + ('alter table foo add column bar int;', True), + ('create table foo (id int);', True), + ('use foo;', True), + ('\\r foo localhost root', True), + ('\\u foo', True), + ('connect foo localhost root', True), + ('drop table foo;', True), + ('rename table foo to bar;', True), + ], +) +def test_need_completion_refresh(queries, expected): + assert need_completion_refresh(queries) is expected + + +def test_need_completion_refresh_ignores_queries_that_fail_to_split(monkeypatch): + class BrokenQuery: + def split(self): + raise RuntimeError('broken') + + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: [BrokenQuery(), 'select 1;']) + + assert need_completion_refresh('ignored') is False + + +@pytest.mark.parametrize( + ('queries', 'expected'), + [ + ('select 1;', False), + ('use foo;', True), + ('\\u foo', True), + ('\\r', False), + ('\\r foo localhost root', True), + ('connect', False), + ('connect foo localhost root', True), + ], +) +def test_need_completion_reset(queries, expected): + assert need_completion_reset(queries) is expected + + +def test_need_completion_reset_ignores_queries_that_fail_to_split(monkeypatch): + class BrokenQuery: + def split(self): + raise RuntimeError('broken') + + monkeypatch.setattr(sql_utils.sqlparse, 'split', lambda _queries: [BrokenQuery(), 'select 1;']) + + assert need_completion_reset('ignored') is False + + +def test_classify_sandbox_statement_treats_token_error_as_quit(monkeypatch): + def raise_token_error(*_args, **_kwargs): + raise sql_utils.sqlglot.errors.TokenError('bad token') + + monkeypatch.setattr(sql_utils.sqlglot, 'tokenize', raise_token_error) + + assert sql_utils.classify_sandbox_statement('`') == ('quit', None) + + +def test_classify_sandbox_statement_treats_empty_tokens_as_quit(monkeypatch): + monkeypatch.setattr(sql_utils.sqlglot, 'tokenize', lambda *_args, **_kwargs: []) + + assert sql_utils.classify_sandbox_statement('ignored') == ('quit', None) + + +def test_find_password_after_eq_returns_none_for_non_string_token() -> None: + token_type = sql_utils.sqlglot.tokens.TokenType + tokens = [ + SimpleNamespace(token_type=token_type.EQ, text='='), + SimpleNamespace(token_type=token_type.VAR, text='CURRENT_USER'), + ] + + assert sql_utils._find_password_after_eq(tokens) is None + + +@pytest.mark.parametrize( + ('text', 'expected'), + [ + ('', ('quit', None)), + (' ', ('quit', None)), + ('quit', ('quit', None)), + ('exit', ('quit', None)), + ('\\q', ('quit', None)), + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", ('alter_user', 'new')), + ('ALTER USER root IDENTIFIED WITH mysql_native_password', ('alter_user', None)), + ("SET PASSWORD = 'newpass'", ('set_password', 'newpass')), + ('SELECT 1', (None, None)), + ], +) +def test_classify_sandbox_statement(text: str, expected: tuple[str | None, str | None]) -> None: + assert sql_utils.classify_sandbox_statement(text) == expected + + +@pytest.mark.parametrize( + ('text', 'expected'), + [ + ('', True), + (' ', True), + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ('alter user root identified by "pw"', True), + ("SET PASSWORD = 'newpass'", True), + ("set password = 'newpass'", True), + ('quit', True), + ('exit', True), + ('\\q', True), + ('SELECT 1', False), + ('DROP TABLE t', False), + ('USE mydb', False), + ('SHOW DATABASES', False), + ], +) +def test_is_sandbox_allowed(text: str, expected: bool) -> None: + assert sql_utils.is_sandbox_allowed(text) is expected + + +@pytest.mark.parametrize( + ('text', 'expected'), + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'new'", True), + ("SET PASSWORD = 'newpass'", True), + ('SELECT 1', False), + ('quit', False), + ], +) +def test_is_password_change(text: str, expected: bool) -> None: + assert sql_utils.is_password_change(text) is expected + + +@pytest.mark.parametrize( + ('text', 'expected'), + [ + ("ALTER USER 'root'@'localhost' IDENTIFIED BY 'newpass'", 'newpass'), + ("SET PASSWORD = 'secret123'", 'secret123'), + ("ALTER USER root IDENTIFIED BY 'p@ss w0rd!'", 'p@ss w0rd!'), + ('ALTER USER root IDENTIFIED WITH mysql_native_password', None), + ('SELECT 1', None), + ], +) +def test_extract_new_password(text: str, expected: str | None) -> None: + assert sql_utils.extract_new_password(text) == expected + + +@pytest.mark.parametrize( + ('status_plain', 'expected'), + [ + (None, False), + ('', False), + ('SELECT 1', False), + ('INSERT 1', True), + ('update 3', True), + ('rename table', True), + ], +) +def test_is_mutating(status_plain, expected): + assert is_mutating(status_plain) is expected + + +@pytest.mark.parametrize( + ('status_plain', 'expected'), + [ + (None, False), + ('', False), + ('SELECT 1', True), + ('select rows', True), + ('UPDATE 1', False), + ], +) +def test_is_select(status_plain, expected): + assert is_select(status_plain) is expected diff --git a/test/pytests/test_sqlcompleter.py b/test/pytests/test_sqlcompleter.py new file mode 100644 index 00000000..1b796eba --- /dev/null +++ b/test/pytests/test_sqlcompleter.py @@ -0,0 +1,649 @@ +# type: ignore + +import re +from types import SimpleNamespace + +from prompt_toolkit.document import Document +import pytest + +import mycli.sqlcompleter +from mycli.sqlcompleter import Fuzziness, SQLCompleter + + +def collect_matches( + orig_text: str, + collection: list[str], + *, + start_only: bool = False, + fuzzy: bool = True, + casing: str | None = None, + text_before_cursor: str = '', +) -> list[tuple[str, int]]: + completer = SQLCompleter() + return list( + completer.find_matches( + orig_text, + collection, + start_only=start_only, + fuzzy=fuzzy, + casing=casing, + text_before_cursor=text_before_cursor, + ) + ) + + +def make_completer(**kwargs) -> SQLCompleter: + comp = SQLCompleter(**kwargs) + comp.keywords = list(comp.keywords) + comp.functions = list(comp.functions) + return comp + + +@pytest.mark.parametrize( + ('item', 'expected'), + [ + ('users', '`users`'), + ('`already`', '`already`'), + ('*', '*'), + ], +) +def test_maybe_quote_identifier(item: str, expected: str) -> None: + completer = SQLCompleter() + assert completer.maybe_quote_identifier(item) == expected + + +def test_quote_collection_if_needed_quotes_when_text_starts_with_backtick() -> None: + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('`us', ['users', '*'], '') + + assert quoted == ['`users`', '*'] + + +def test_quote_collection_if_needed_quotes_when_cursor_is_inside_backticks() -> None: + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('us', ['users', '`uuid`'], 'select `us') + + assert quoted == ['`users`', '`uuid`'] + + +def test_quote_collection_if_needed_leaves_collection_unchanged_when_not_quoted() -> None: + collection = ['users', '*'] + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('us', collection, 'select us') + + assert quoted is collection + + +@pytest.mark.parametrize( + ('text_parts', 'item_parts', 'expected'), + [ + (['us', 'de', 'fu'], ['user', 'defined', 'function'], True), + (['us', 'fu'], ['user', 'defined', 'function'], True), + (['us', 'zz'], ['user', 'defined', 'function'], False), + ([], ['user', 'defined', 'function'], True), + (['us'], [], False), + ], +) +def test_word_parts_match( + text_parts: list[str], + item_parts: list[str], + expected: bool, +) -> None: + completer = SQLCompleter() + assert completer.word_parts_match(text_parts, item_parts) is expected + + +@pytest.mark.parametrize( + ('item', 'pattern', 'under_words_text', 'case_words_text', 'expected'), + [ + ('foo_select_bar', re.compile('(s.{0,3}?e.{0,3}?l)'), ['sel'], ['sel'], Fuzziness.REGEX), + ('user_defined_function', re.compile('(z.{0,3}?z)'), ['us', 'de', 'fu'], ['us_de_fu'], Fuzziness.UNDER_WORDS), + ('TimeZoneTransitionType', re.compile('(Ti.{0,3}?Zx)'), ['TiZoTrTy'], ['Ti', 'Zo', 'Tr', 'Ty'], Fuzziness.CAMEL_CASE), + ('orders', re.compile('(z.{0,3}?z)'), ['zz'], ['zz'], None), + ], +) +def test_find_fuzzy_match( + item: str, + pattern: re.Pattern[str], + under_words_text: list[str], + case_words_text: list[str], + expected: int | None, +) -> None: + completer = SQLCompleter() + assert completer.find_fuzzy_match(item, pattern, under_words_text, case_words_text) == expected + + +def test_find_fuzzy_matches_collects_item_level_matches(monkeypatch) -> None: + monkeypatch.setattr( + SQLCompleter, + 'find_fuzzy_match', + lambda self, item, pattern, under_words_text, case_words_text: { + 'orders': Fuzziness.REGEX, + 'order_items': Fuzziness.UNDER_WORDS, + 'other': None, + }[item], + ) + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', lambda *args, **kwargs: []) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('OrIt', 'orit', ['orders', 'order_items', 'other']) + + assert matches == [ + ('orders', Fuzziness.REGEX), + ('order_items', Fuzziness.UNDER_WORDS), + ] + + +def test_find_fuzzy_matches_skips_rapidfuzz_for_short_text(monkeypatch) -> None: + monkeypatch.setattr(SQLCompleter, 'find_fuzzy_match', lambda *args, **kwargs: None) + + def fail_extract(*args, **kwargs): + raise AssertionError('rapidfuzz should not be called') + + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', fail_extract) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('sel', 'sel', ['SELECT']) + + assert matches == [] + + +def test_find_fuzzy_matches_appends_rapidfuzz_results_and_skips_duplicates(monkeypatch) -> None: + monkeypatch.setattr( + SQLCompleter, + 'find_fuzzy_match', + lambda self, item, pattern, under_words_text, case_words_text: Fuzziness.REGEX if item == 'alphabet' else None, + ) + monkeypatch.setattr( + mycli.sqlcompleter.rapidfuzz.process, + 'extract', + lambda *args, **kwargs: [('abc', 99, 0), ('alphabet', 95, 1), ('alphanumeric', 90, 2)], + ) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('alpahet', 'alpahet', ['abc', 'alphabet', 'alphanumeric']) + + assert matches == [ + ('alphabet', Fuzziness.REGEX), + ('alphanumeric', Fuzziness.RAPIDFUZZ), + ] + + +@pytest.mark.parametrize('existing_fuzziness', [Fuzziness.PERFECT, Fuzziness.CAMEL_CASE, Fuzziness.RAPIDFUZZ]) +def test_find_fuzzy_matches_skips_rapidfuzz_duplicates_for_remaining_fuzziness_types( + monkeypatch, + existing_fuzziness: Fuzziness, +) -> None: + monkeypatch.setattr( + SQLCompleter, + 'find_fuzzy_match', + lambda self, item, pattern, under_words_text, case_words_text: existing_fuzziness if item == 'alphabet' else None, + ) + monkeypatch.setattr( + mycli.sqlcompleter.rapidfuzz.process, + 'extract', + lambda *args, **kwargs: [('alphabet', 95, 0)], + ) + completer = SQLCompleter() + + matches = completer.find_fuzzy_matches('alpahet', 'alpahet', ['alphabet']) + + assert matches == [('alphabet', existing_fuzziness)] + + +@pytest.mark.parametrize( + ('text', 'collection', 'start_only', 'expected'), + [ + ('ord', ['orders', 'user_orders'], True, [('orders', Fuzziness.PERFECT)]), + ('name', ['table_name', 'name_table'], False, [('table_name', Fuzziness.PERFECT), ('name_table', Fuzziness.PERFECT)]), + ('', ['orders', 'users'], True, [('orders', Fuzziness.PERFECT), ('users', Fuzziness.PERFECT)]), + ], +) +def test_find_perfect_matches( + text: str, + collection: list[str], + start_only: bool, + expected: list[tuple[str, int]], +) -> None: + completer = SQLCompleter() + assert completer.find_perfect_matches(text, collection, start_only) == expected + + +@pytest.mark.parametrize( + ('casing', 'last', 'expected'), + [ + (None, 'Sel', None), + ('upper', 'sel', 'upper'), + ('lower', 'SEL', 'lower'), + ('auto', 'sel', 'lower'), + ('auto', 'SEl', 'lower'), + ('auto', 'SEL', 'upper'), + ('auto', '', 'upper'), + ], +) +def test_resolve_casing(casing: str | None, last: str, expected: str | None) -> None: + completer = SQLCompleter() + assert completer.resolve_casing(casing, last) == expected + + +@pytest.mark.parametrize( + ('completions', 'casing', 'expected'), + [ + ([('Select', Fuzziness.REGEX)], None, [('Select', Fuzziness.REGEX)]), + ([('Select', Fuzziness.REGEX)], 'upper', [('SELECT', Fuzziness.REGEX)]), + ([('Select', Fuzziness.REGEX)], 'lower', [('select', Fuzziness.REGEX)]), + ( + [('Select', Fuzziness.REGEX), ('From', Fuzziness.PERFECT)], + 'upper', + [('SELECT', Fuzziness.REGEX), ('FROM', Fuzziness.PERFECT)], + ), + ], +) +def test_apply_casing( + completions: list[tuple[str, int]], + casing: str | None, + expected: list[tuple[str, int]], +) -> None: + completer = SQLCompleter() + assert list(completer.apply_casing(completions, casing)) == expected + + +def test_find_matches_uses_last_word_for_prefix_matching() -> None: + matches = collect_matches( + 'select ord', + ['orders', 'user_orders'], + start_only=True, + fuzzy=False, + ) + + assert matches == [('orders', Fuzziness.PERFECT)] + + +def test_find_matches_supports_substring_matching() -> None: + matches = collect_matches( + 'name', + ['table_name', 'name_table'], + start_only=False, + fuzzy=False, + ) + + assert matches == [ + ('table_name', Fuzziness.PERFECT), + ('name_table', Fuzziness.PERFECT), + ] + + +def test_find_matches_quotes_identifiers_when_text_starts_with_backtick() -> None: + matches = collect_matches('`us', ['users']) + + assert matches == [('`users`', Fuzziness.REGEX)] + + +def test_find_matches_quotes_identifiers_when_cursor_is_inside_backticks() -> None: + matches = collect_matches( + 'uu', + ['users', '`uuid`'], + text_before_cursor='select `uu', + ) + + assert matches == [('`uuid`', Fuzziness.REGEX)] + + +def test_find_matches_preserves_asterisk_inside_backticks() -> None: + matches = collect_matches( + '*', + ['*'], + text_before_cursor='select `*', + ) + + assert matches == [('*', Fuzziness.REGEX)] + + +def test_find_matches_finds_regex_matches() -> None: + matches = collect_matches('sel', ['SELECT', 'foo_select_bar']) + + assert matches == [ + ('SELECT', Fuzziness.REGEX), + ('foo_select_bar', Fuzziness.REGEX), + ] + + +def test_find_matches_finds_under_word_matches() -> None: + matches = collect_matches('us_de_fu', ['user_defined_function']) + + assert matches == [('user_defined_function', Fuzziness.UNDER_WORDS)] + + +def test_find_matches_finds_camel_case_matches(monkeypatch) -> None: + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', lambda *args, **kwargs: []) + + matches = collect_matches('TiZoTrTy', ['TimeZoneTransitionType']) + + assert matches == [('TimeZoneTransitionType', Fuzziness.CAMEL_CASE)] + + +def test_find_matches_finds_rapidfuzz_matches() -> None: + matches = collect_matches('sleect', ['SELECT']) + + assert matches == [('SELECT', Fuzziness.RAPIDFUZZ)] + + +def test_find_matches_skips_rapidfuzz_for_short_text(monkeypatch) -> None: + def fail_extract(*args, **kwargs): + raise AssertionError('rapidfuzz should not be called') + + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', fail_extract) + + matches = collect_matches('sel', ['SELECT']) + + assert matches == [('SELECT', Fuzziness.REGEX)] + + +def test_find_matches_filters_short_rapidfuzz_candidates(monkeypatch) -> None: + monkeypatch.setattr( + mycli.sqlcompleter.rapidfuzz.process, + 'extract', + lambda *args, **kwargs: [('abc', 99, 0), ('alphabet', 95, 1)], + ) + + matches = collect_matches('alpahet', ['abc', 'alphabet']) + + assert matches == [('alphabet', Fuzziness.RAPIDFUZZ)] + + +@pytest.mark.parametrize( + ('orig_text', 'collection', 'casing', 'expected'), + [ + ('sel', ['SELECT'], 'auto', [('select', Fuzziness.REGEX)]), + ('SEL', ['select'], 'auto', [('SELECT', Fuzziness.REGEX)]), + ('sel', ['select'], 'upper', [('SELECT', Fuzziness.REGEX)]), + ('SEL', ['SELECT'], 'lower', [('select', Fuzziness.REGEX)]), + ], +) +def test_find_matches_applies_casing( + orig_text: str, + collection: list[str], + casing: str, + expected: list[tuple[str, int]], +) -> None: + matches = collect_matches(orig_text, collection, casing=casing) + + assert matches == expected + + +def test_init_invalid_keyword_casing_defaults_to_auto() -> None: + completer = SQLCompleter(keyword_casing='invalid') + + assert completer.keyword_casing == 'auto' + + +def test_extend_metadata_helpers_and_logging(caplog) -> None: + completer = make_completer() + completer.set_dbname('missing') + + completer.extend_keywords(['ZZZ']) + assert 'ZZZ' in completer.keywords + assert 'ZZZ' in completer.all_completions + + completer.extend_keywords(['ONLY_THIS'], replace=True) + assert completer.keywords == ['ONLY_THIS'] + assert 'ONLY_THIS' in completer.all_completions + + completer.extend_show_items([('FULL TABLES',), ('STATUS',)]) + completer.extend_change_items([('MASTER TO',)]) + completer.extend_users([('app_user',)]) + assert completer.show_items == ['FULL TABLES', 'STATUS'] + assert 'MASTER TO' in completer.change_items + assert 'app_user' in completer.users + + completer.extend_schemata(None) + assert '' not in completer.dbmetadata['tables'] + + with caplog.at_level('ERROR', logger='mycli.sqlcompleter'): + completer.extend_relations([('orders',)], kind='tables') + assert "listed in unrecognized schema 'missing'" in caplog.text + + completer.extend_schemata('test') + completer.set_dbname('test') + completer.extend_relations([('select',)], kind='tables') + + caplog.clear() + with caplog.at_level('ERROR', logger='mycli.sqlcompleter'): + completer.extend_columns([('missing', 'id'), ('select', 'from')], kind='tables') + assert "relname 'missing' was not found in db 'test'" in caplog.text + assert completer.dbmetadata['tables']['test']['`select`'] == ['*', '`from`'] + + completer.set_dbname('enumdb') + completer.extend_enum_values([('order status', 'select', ['pending'])]) + assert completer.dbmetadata['enum_values']['enumdb']['`order status`']['`select`'] == ['pending'] + + +def test_extend_functions_procedures_character_sets_and_collations() -> None: + completer = make_completer() + completer.extend_schemata('test') + completer.set_dbname('test') + + completer.extend_functions(['BUILTIN_X'], builtin=True) + assert 'BUILTIN_X' in completer.functions + + def broken_functions(): + raise RuntimeError('boom') + yield ('ignored', 'ignored') + + completer.extend_functions(broken_functions()) + completer.extend_functions(iter([('quoted func', 'meta')])) + assert '`quoted func`' in completer.dbmetadata['functions']['test'] + + completer.extend_procedures(iter([(), (None,), ('proc_demo',)])) + assert 'proc_demo' in completer.dbmetadata['procedures']['test'] + + completer.extend_character_sets(iter([(), (None,), ('utf8mb4',)])) + completer.extend_collations(iter([(), (None,), ('utf8mb4_unicode_ci',)])) + assert completer.character_sets == ['utf8mb4'] + assert completer.collations == ['utf8mb4_unicode_ci'] + + +def test_extend_procedures_initializes_schema_metadata_when_missing() -> None: + completer = make_completer() + completer.set_dbname('procdb') + + completer.extend_procedures(iter([('proc_demo',)])) + + assert completer.dbmetadata['procedures']['procdb']['proc_demo'] is None + + +def test_get_completions_drop_unique_columns(monkeypatch) -> None: + completer = make_completer() + completer.extend_schemata('test') + completer.set_dbname('test') + completer.dbmetadata['tables']['test'] = { + 't1': ['*', 'id', 'name'], + 't2': ['*', 'id', 'email'], + } + + monkeypatch.setattr( + mycli.sqlcompleter, + 'suggest_type', + lambda text, before: [{'type': 'column', 'tables': [(None, 't1', None), (None, 't2', None)], 'drop_unique': True}], + ) + + result = [c.text for c in completer.get_completions(Document(text='SELECT ', cursor_position=7), None)] + + assert result == ['id'] + + +@pytest.mark.parametrize( + ('suggestion', 'setup', 'text', 'expected'), + [ + ({'type': 'procedure', 'schema': 'test'}, lambda c, m: c.extend_procedures(iter([('proc_demo',)])), 'CALL pro', 'proc_demo'), + ({'type': 'show'}, lambda c, m: c.extend_show_items([('TABLE STATUS',)]), 'SHOW tab', 'table status'), + ({'type': 'change'}, lambda c, m: c.extend_change_items([('MASTER TO',)]), 'CHANGE ma', 'MASTER TO'), + ({'type': 'user'}, lambda c, m: c.extend_users([('app_user',)]), 'GRANT app', 'app_user'), + ( + {'type': 'favoritequery'}, + lambda c, m: m.setattr( + mycli.sqlcompleter.FavoriteQueries, 'instance', SimpleNamespace(list=lambda: ['daily_report']), raising=False + ), + '\\f dai', + 'daily_report', + ), + ({'type': 'table_format'}, lambda c, m: None, 'fmt c', 'csv'), + ], +) +def test_get_completions_branch_specific_suggestions(monkeypatch, suggestion, setup, text, expected) -> None: + completer = make_completer(supported_formats=('csv', 'tsv')) + completer.extend_schemata('test') + completer.set_dbname('test') + setup(completer, monkeypatch) + monkeypatch.setattr(mycli.sqlcompleter, 'suggest_type', lambda full_text, before: [suggestion]) + + result = [c.text for c in completer.get_completions(Document(text=text, cursor_position=len(text)), None)] + + assert expected in result + + +def test_get_completions_llm_branch_with_and_without_current_word(monkeypatch) -> None: + tokens_seen: list[list[str]] = [] + + def fake_get_completions(tokens: list[str]) -> list[str]: + tokens_seen.append(tokens) + return ['chat', 'explain'] + + monkeypatch.setattr(mycli.sqlcompleter, 'suggest_type', lambda full_text, before: [{'type': 'llm'}]) + monkeypatch.setattr(mycli.sqlcompleter.llm, 'get_completions', fake_get_completions) + + completer = make_completer() + + blank_word = [c.text for c in completer.get_completions(Document(text='\\llm ', cursor_position=5), None)] + partial_text = '\\llm ask ch' + partial_word = [c.text for c in completer.get_completions(Document(text=partial_text, cursor_position=len(partial_text)), None)] + + assert tokens_seen == [[], ['ask']] + assert 'chat' in blank_word + assert 'chat' in partial_word + assert 'explain' in blank_word + assert 'explain' not in partial_word + + +def test_find_files_populate_scoped_cols_and_enum_helpers(monkeypatch) -> None: + completer = make_completer() + completer.extend_schemata('test') + completer.set_dbname('test') + completer.dbmetadata['tables']['test']['`select`'] = ['id'] + completer.dbmetadata['views']['test']['orders_view'] = ['view_id'] + completer.extend_enum_values([('orders', 'status', ['pending', 'shipped'])]) + + monkeypatch.setattr(mycli.sqlcompleter, 'parse_path', lambda word: ('/tmp', 'fi', 0)) + monkeypatch.setattr(mycli.sqlcompleter, 'suggest_path', lambda word: ['file.sql', 'folder/']) + monkeypatch.setattr(mycli.sqlcompleter, 'complete_path', lambda name, last_path: name if name == 'file.sql' else None) + + assert list(completer.find_files('./fi')) == [('file.sql', Fuzziness.PERFECT)] + assert completer.populate_scoped_cols([(None, 'select', None), (None, 'orders_view', None), (None, 'missing', None)]) == [ + 'id', + 'view_id', + ] + assert completer.populate_enum_values([(None, 'orders', 'o')], 'status', parent='other') == [] + assert completer.populate_enum_values([(None, 'orders', 'o')], 'status', parent='o') == ['pending', 'shipped'] + assert completer._quote_sql_string("O'Reilly") == "'O''Reilly'" + + +@pytest.mark.parametrize( + ('name', 'expected'), + [ + ('`quoted`', 'quoted'), + ('plain', 'plain'), + (None, ''), + ], +) +def test_strip_backticks(name: str | None, expected: str) -> None: + assert SQLCompleter._strip_backticks(name) == expected + + +@pytest.mark.parametrize( + ('parent', 'schema', 'relname', 'alias', 'expected'), + [ + ('o', None, 'orders', 'o', True), + ('orders', None, 'orders', None, True), + ('test.orders', 'test', 'orders', None, True), + ('other', 'test', 'orders', 'o', False), + ], +) +def test_matches_parent(parent: str, schema: str | None, relname: str, alias: str | None, expected: bool) -> None: + assert SQLCompleter._matches_parent(parent, schema, relname, alias) is expected + + +def test_copy_other_schemas_from_preserves_non_current_metadata() -> None: + source = SQLCompleter() + source.load_schema_metadata( + schema='other', + table_columns={'users': ['*', 'id', 'email']}, + foreign_keys={'tables': {}, 'relations': []}, + enum_values={}, + functions={'fn_foo': None}, + procedures={}, + ) + # Also populate the source's "current" schema; it should NOT be copied. + source.load_schema_metadata( + schema='current', + table_columns={'stale_current': ['*']}, + foreign_keys={'tables': {}, 'relations': []}, + enum_values={}, + functions={}, + procedures={}, + ) + + dest = SQLCompleter() + dest.set_dbname('current') + dest.extend_schemata('current') + + dest.copy_other_schemas_from(source, exclude='current') + + assert 'other' in dest.dbmetadata['tables'] + assert dest.dbmetadata['tables']['other'] == {'users': ['*', 'id', 'email']} + assert dest.dbmetadata['functions']['other'] == {'fn_foo': None} + # The excluded schema is not overwritten with stale source data. + assert dest.dbmetadata['tables']['current'] == {} + # Completion lookups pick up the copied names. + assert 'users' in dest.all_completions + assert 'email' in dest.all_completions + assert 'fn_foo' in dest.all_completions + + +def test_copy_other_schemas_from_does_not_overwrite_existing_dest() -> None: + source = SQLCompleter() + source.load_schema_metadata( + schema='shared', + table_columns={'from_source': ['*']}, + foreign_keys={'tables': {}, 'relations': []}, + enum_values={}, + functions={}, + procedures={}, + ) + + dest = SQLCompleter() + dest.set_dbname('current') + dest.dbmetadata['tables']['shared'] = {'from_dest': ['*']} + + dest.copy_other_schemas_from(source, exclude='current') + + # Destination's existing data wins over source when a conflict exists. + assert dest.dbmetadata['tables']['shared'] == {'from_dest': ['*']} + + +def test_load_schema_metadata_ignores_empty_schema() -> None: + completer = SQLCompleter() + + completer.load_schema_metadata( + schema='', + table_columns={'users': ['*', 'id']}, + foreign_keys={'tables': {'users': []}, 'relations': [('users', 'id')]}, + enum_values={'users': {'status': ['pending']}}, + functions={'fn_users': None}, + procedures={'proc_users': None}, + ) + + assert completer.dbmetadata['tables'] == {} + assert completer.dbmetadata['views'] == {} + assert completer.dbmetadata['functions'] == {} + assert completer.dbmetadata['procedures'] == {} + assert completer.dbmetadata['enum_values'] == {} + assert completer.dbmetadata['foreign_keys'] == {} + assert 'users' not in completer.all_completions + assert 'fn_users' not in completer.all_completions diff --git a/test/pytests/test_sqlexecute.py b/test/pytests/test_sqlexecute.py new file mode 100644 index 00000000..807dc3b7 --- /dev/null +++ b/test/pytests/test_sqlexecute.py @@ -0,0 +1,1712 @@ +# type: ignore + +import builtins +from datetime import time +import importlib.util +import os +from pathlib import Path +import sys +from types import SimpleNamespace + +from prompt_toolkit.formatted_text import FormattedText +import pymysql +import pytest + +from mycli.constants import TEST_DATABASE +from mycli.packages.special import iocommands +from mycli.packages.sqlresult import SQLResult +import mycli.sqlexecute as sqlexecute +from mycli.sqlexecute import ServerInfo, ServerSpecies, SQLExecute +from test.utils import dbtest, is_expanded_output, run, set_expanded_output + + +def assert_result_equal( + result, + preamble=None, + header=None, + rows=None, + status=None, + status_plain=None, + postamble=None, + auto_status=True, + assert_contains=False, +): + """Assert that an sqlexecute.run() result matches the expected values.""" + if status_plain is None and auto_status and rows: + status_plain = f"{len(rows)} row{'s' if len(rows) > 1 else ''} in set" + status = FormattedText([('', status_plain)]) + fields = { + "preamble": preamble, + "header": header, + "rows": rows, + "postamble": postamble, + "status": status, + "status_plain": status_plain, + } + + if assert_contains: + # Do a loose match on the results using the *in* operator. + for key, field in fields.items(): + if field: + assert field in result[0][key] + else: + # Do an exact match on the fields. + assert result == [fields] + + +@dbtest +def test_timediff_negative_value(executor): + sql = "select timediff('2020-11-11 01:01:01', '2020-11-11 01:02:01')" + result = run(executor, sql) + # negative value comes back as str + assert result[0]["rows"][0][0] == "-00:01:00" + + +@dbtest +def test_timediff_positive_value(executor): + sql = "select timediff('2020-11-11 01:02:01', '2020-11-11 01:01:01')" + result = run(executor, sql) + # positive value comes back as datetime.time + assert result[0]["rows"][0][0] == time(0, 1) + + +@dbtest +def test_get_result_status_without_warning(executor): + sql = "select 1" + result = run(executor, sql) + assert result[0]["status_plain"] == "1 row in set" + + +@dbtest +def test_get_result_status_with_warning(executor): + sql = "SELECT 1 + '0 foo'" + result = run(executor, sql) + assert result[0]["status"] == FormattedText([ + ('', '1 row in set'), + ('', ', '), + ('class:output.status.warning-count', '1 warning'), + ]) + assert result[0]["status_plain"] == "1 row in set, 1 warning" + + +@dbtest +def test_conn(executor): + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + results = run(executor, """select * from test""") + + assert_result_equal(results, header=["a"], rows=[("abc",)]) + + +@dbtest +def test_bools(executor): + run(executor, """create table test(a boolean)""") + run(executor, """insert into test values(True)""") + results = run(executor, """select * from test""") + + assert_result_equal(results, header=["a"], rows=[(1,)]) + + +@dbtest +def test_binary(executor): + run(executor, """create table bt(geom linestring NOT NULL)""") + run(executor, "INSERT INTO bt VALUES (ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") + results = run(executor, """select * from bt""") + + geom = ( + b"\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n" + b"\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9" + b"\xac\xdeC@" + ) + + assert_result_equal(results, header=["geom"], rows=[(geom,)]) + + +@dbtest +def test_table_and_columns_query(executor): + run(executor, "create table a(x text, y text)") + run(executor, "create table b(z text)") + + assert set(executor.tables()) == {("a",), ("b",)} + assert set(executor.table_columns()) == {("a", "x"), ("a", "y"), ("b", "z")} + + +@dbtest +def test_database_list(executor): + databases = executor.databases() + assert TEST_DATABASE in databases + + +@dbtest +def test_invalid_syntax(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, "invalid syntax!") + assert "You have an error in your SQL syntax;" in str(excinfo.value) + + +@dbtest +def test_invalid_column_name(executor): + with pytest.raises(pymysql.err.OperationalError) as excinfo: + run(executor, "select invalid command") + assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) + + +@dbtest +def test_unicode_support_in_output(executor): + run(executor, "create table unicodechars(t text)") + run(executor, "insert into unicodechars (t) values ('é')") + + # See issue #24, this raises an exception without proper handling + results = run(executor, "select * from unicodechars") + assert_result_equal(results, header=["t"], rows=[("é",)]) + + +@dbtest +def test_multiple_queries_same_line(executor): + results = run(executor, "select 'foo'; select 'bar'") + + expected = [ + { + "preamble": None, + "header": ["foo"], + "rows": [("foo",)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + { + "preamble": None, + "header": ["bar"], + "rows": [("bar",)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + ] + assert expected == results + + +@dbtest +def test_multiple_queries_same_line_syntaxerror(executor): + with pytest.raises(pymysql.ProgrammingError) as excinfo: + run(executor, "select 'foo'; invalid syntax") + assert "You have an error in your SQL syntax;" in str(excinfo.value) + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") +def test_favorite_query(executor, monkeypatch): + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', iocommands.favoritequeries, raising=False) + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-a select * from test where a like 'a%'") + assert_result_equal(results, status="Saved.", status_plain="Saved.") + + results = run(executor, "\\f test-a") + assert_result_equal(results, preamble="> select * from test where a like 'a%'", header=["a"], rows=[("abc",)], auto_status=False) + + results = run(executor, "\\fd test-a") + assert_result_equal(results, status="test-a: Deleted.", status_plain="test-a: Deleted.") + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") +def test_favorite_query_multiple_statement(executor, monkeypatch): + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', iocommands.favoritequeries, raising=False) + set_expanded_output(False) + run(executor, "create table test(a text)") + run(executor, "insert into test values('abc')") + run(executor, "insert into test values('def')") + + results = run(executor, "\\fs test-ad select * from test where a like 'a%'; select * from test where a like 'd%'") + assert_result_equal(results, status="Saved.", status_plain="Saved.") + + results = run(executor, "\\f test-ad") + expected = [ + { + "preamble": "> select * from test where a like 'a%'", + "header": ["a"], + "rows": [("abc",)], + "postamble": None, + "status": None, + "status_plain": None, + }, + { + "preamble": "> select * from test where a like 'd%'", + "header": ["a"], + "rows": [("def",)], + "postamble": None, + "status": None, + "status_plain": None, + }, + ] + assert expected == results + + results = run(executor, "\\fd test-ad") + assert_result_equal(results, status="test-ad: Deleted.", status_plain="test-ad: Deleted.") + + +@dbtest +@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") +def test_favorite_query_expanded_output(executor, monkeypatch): + monkeypatch.setattr(iocommands.FavoriteQueries, 'instance', iocommands.favoritequeries, raising=False) + set_expanded_output(False) + run(executor, """create table test(a text)""") + run(executor, """insert into test values('abc')""") + + results = run(executor, "\\fs test-ae select * from test") + assert_result_equal(results, status="Saved.", status_plain="Saved.") + + results = run(executor, "\\f test-ae \\G") + assert is_expanded_output() is True + assert_result_equal(results, preamble="> select * from test", header=["a"], rows=[("abc",)], auto_status=False) + + set_expanded_output(False) + + results = run(executor, "\\fd test-ae") + assert_result_equal(results, status="test-ae: Deleted.", status_plain="test-ae: Deleted.") + + +@dbtest +def test_collapsed_output_special_command(executor): + set_expanded_output(True) + run(executor, "select 1\\g") + assert is_expanded_output() is False + + +@dbtest +def test_special_command(executor): + results = run(executor, "\\?") + assert_result_equal(results, rows=("quit", "\\q", "quit", "Quit."), header="Command", assert_contains=True, auto_status=False) + + +@dbtest +def test_cd_command_without_a_folder_name(executor): + results = run(executor, "system cd") + assert_result_equal( + results, status="Exactly one directory name must be provided.", status_plain="Exactly one directory name must be provided." + ) + + +@dbtest +def test_cd_command_with_one_nonexistent_folder_name(executor): + results = run(executor, 'system cd nonexistent_folder_name') + assert_result_equal(results, status='No such file or directory', status_plain='No such file or directory') + + +@dbtest +def test_cd_command_with_one_real_folder_name(executor, tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + doc_dir = tmp_path / 'doc' + doc_dir.mkdir() + results = run(executor, 'system cd doc') + # todo would be better to capture stderr but there was a problem with capsys + assert results[0]['status_plain'] is None + + +@dbtest +def test_cd_command_with_two_folder_names(executor): + results = run(executor, "system cd one two") + assert_result_equal( + results, status='Exactly one directory name must be provided.', status_plain='Exactly one directory name must be provided.' + ) + + +@dbtest +def test_cd_command_unbalanced(executor): + results = run(executor, "system cd 'one") + assert_result_equal( + results, + status='Cannot parse system command: No closing quotation', + status_plain='Cannot parse system command: No closing quotation', + ) + + +@dbtest +def test_system_command_not_found(executor): + results = run(executor, "system xyz") + if os.name == "nt": + assert_result_equal(results, status_plain="OSError: The system cannot find the file specified", assert_contains=True) + else: + assert_result_equal(results, status_plain="OSError: No such file or directory", assert_contains=True) + + +@dbtest +def test_system_command_output(executor): + eol = os.linesep + results = run(executor, "system echo mycli rocks") + assert_result_equal(results, preamble=f"mycli rocks{eol}") + + +@dbtest +def test_cd_command_current_dir(executor): + test_path = os.path.abspath(os.path.dirname(__file__)) + run(executor, f"system cd {test_path}") + assert os.getcwd() == test_path + + +@dbtest +def test_unicode_support(executor): + results = run(executor, "SELECT '日本語' AS japanese;") + assert_result_equal(results, header=["japanese"], rows=[("日本語",)]) + + +@dbtest +def test_timestamp_null(executor): + run(executor, """create table ts_null(a timestamp null)""") + run(executor, """insert into ts_null values(null)""") + results = run(executor, """select * from ts_null""") + assert_result_equal(results, header=["a"], rows=[(None,)]) + + +@dbtest +def test_datetime_null(executor): + run(executor, """create table dt_null(a datetime null)""") + run(executor, """insert into dt_null values(null)""") + results = run(executor, """select * from dt_null""") + assert_result_equal(results, header=["a"], rows=[(None,)]) + + +@dbtest +def test_date_null(executor): + run(executor, """create table date_null(a date null)""") + run(executor, """insert into date_null values(null)""") + results = run(executor, """select * from date_null""") + assert_result_equal(results, header=["a"], rows=[(None,)]) + + +@dbtest +def test_time_null(executor): + run(executor, """create table time_null(a time null)""") + run(executor, """insert into time_null values(null)""") + results = run(executor, """select * from time_null""") + assert_result_equal(results, header=["a"], rows=[(None,)]) + + +@dbtest +def test_multiple_results(executor): + query = """CREATE PROCEDURE dmtest() + BEGIN + SELECT 1; + SELECT 2; + END""" + executor.conn.cursor().execute(query) + + results = run(executor, "call dmtest;") + expected = [ + { + "preamble": None, + "header": ["1"], + "rows": [(1,)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + { + "preamble": None, + "header": ["2"], + "rows": [(2,)], + "postamble": None, + "status_plain": "1 row in set", + 'status': FormattedText([('', '1 row in set')]), + }, + ] + assert results == expected + + +@pytest.mark.parametrize( + "version_string, species, parsed_version_string, version", + ( + ("5.7.25-TiDB-v6.1.0", "TiDB", "6.1.0", 60100), + ("8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa", "TiDB", "7.2.0", 70200), + ("5.7.32-35", "Percona", "5.7.32", 50732), + ("5.7.32-0ubuntu0.18.04.1", "MySQL", "5.7.32", 50732), + ("10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal", "MariaDB", "10.5.8", 100508), + ("5.0.16-pro-nt-log", "MySQL", "5.0.16", 50016), + ("5.1.5a-alpha", "MySQL", "5.1.5", 50105), + ("unexpected version string", None, "", 0), + ("", None, "", 0), + (None, None, "", 0), + ), +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.MySQL + assert server_info.version_str == parsed_version_string + assert server_info.version == version + + +@pytest.mark.parametrize( + 'version_string, expected', + ( + ('5.7.32', 50732), + ('8.0.11', 80011), + ('10.5.8', 100508), + ), +) +def test_calc_mysql_version_value(version_string: str, expected: int) -> None: + assert ServerInfo.calc_mysql_version_value(version_string) == expected + + +@pytest.mark.parametrize( + 'version_string', + ( + None, + '', + 123, + '8.0', + '8.0.11.1', + 'unexpected version string', + ), +) +def test_calc_mysql_version_value_returns_zero_for_invalid_input(version_string: object) -> None: + assert ServerInfo.calc_mysql_version_value(version_string) == 0 + + +@pytest.mark.parametrize('version_string', ('8.0.x', '8.x.11', 'x.0.11')) +def test_calc_mysql_version_value_raises_for_non_numeric_parts(version_string: str) -> None: + with pytest.raises(ValueError): + ServerInfo.calc_mysql_version_value(version_string) + + +def test_sqlexecute_import_swallows_optional_dependency_import_errors(monkeypatch) -> None: + assert sqlexecute.__file__ is not None + original_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): # noqa: A002 + if name == 'paramiko': + raise ImportError('missing optional dependency') + return original_import(name, globals, locals, fromlist, level) + + module_name = 'sqlexecute_importerror_test' + spec = importlib.util.spec_from_file_location(module_name, Path(sqlexecute.__file__)) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + monkeypatch.setattr(builtins, '__import__', fake_import) + sys.modules[module_name] = module + try: + spec.loader.exec_module(module) + finally: + sys.modules.pop(module_name, None) + + +@pytest.mark.parametrize( + ('server_info', 'expected'), + ( + (ServerInfo(ServerSpecies.MySQL, '8.0.36'), 'MySQL 8.0.36'), + (ServerInfo(None, '8.0.36'), '8.0.36'), + ), +) +def test_server_info_string_representation(server_info: ServerInfo, expected: str) -> None: + assert str(server_info) == expected + + +@pytest.mark.parametrize( + 'column_type, expected', + ( + ("enum('small','medium','large')", ["small", "medium", "large"]), + ("ENUM('yes','no')", ["yes", "no"]), + ("enum('a,b','c')", ["a,b", "c"]), + ("enum('it''s','can\\\\t')", ["it's", "can\\t"]), + ), +) +def test_parse_enum_values(column_type: str, expected: list[str]) -> None: + assert SQLExecute._parse_enum_values(column_type) == expected + + +@pytest.mark.parametrize('column_type', ('', 'varchar(255)', "set('a','b')", None)) +def test_parse_enum_values_returns_empty_list_for_non_enum_input(column_type: str | None) -> None: + assert SQLExecute._parse_enum_values(column_type) == [] + + +class DummyConnection: + def __init__(self, server_version: str, close_error: Exception | None = None) -> None: + self.server_version = server_version + self.host = 'initial-host' + self.port = 3306 + self.close_calls = 0 + self.connect_calls = 0 + self.close_error = close_error + + def close(self) -> None: + self.close_calls += 1 + if self.close_error is not None: + raise self.close_error + + def connect(self) -> None: + self.connect_calls += 1 + + +class FakeQueryCursor: + def __init__( + self, + nextset_steps: list[tuple[bool, int, object | None]] | None = None, + ) -> None: + self.executed: list[str] = [] + self.rowcount = 1 + self.description: object | None = [('column',)] + self.warning_count = 0 + self._nextset_steps = list(nextset_steps or []) + + def execute(self, sql: str) -> None: + self.executed.append(sql) + + def nextset(self) -> bool: + if not self._nextset_steps: + return False + + has_next, rowcount, description = self._nextset_steps.pop(0) + self.rowcount = rowcount + self.description = description + return has_next + + +class FakeQueryConnection: + def __init__(self, cursors: list[FakeQueryCursor]) -> None: + self.cursors = list(cursors) + self.cursor_calls = 0 + + def cursor(self) -> FakeQueryCursor: + cursor = self.cursors[self.cursor_calls] + self.cursor_calls += 1 + return cursor + + +class FakeMetadataCursor: + def __init__( + self, + rows: list[tuple[object, ...]], + execute_error: Exception | None = None, + ) -> None: + self.rows = rows + self.execute_error = execute_error + self.executed: list[tuple[str, tuple[object, ...] | None]] = [] + self.entered = False + self.exited = False + + def __enter__(self) -> 'FakeMetadataCursor': + self.entered = True + return self + + def __exit__(self, exc_type: object, exc: object, tb: object) -> None: + self.exited = True + + def execute(self, sql: str, params: tuple[object, ...] | None = None) -> None: + self.executed.append((sql, params)) + if self.execute_error is not None: + raise self.execute_error + + def fetchall(self) -> list[tuple[object, ...]]: + return self.rows + + def fetchone(self) -> tuple[object, ...] | None: + if self.rows: + return self.rows[0] + return None + + def __iter__(self): + return iter(self.rows) + + +class FakeMetadataConnection: + def __init__(self, cursor: FakeMetadataCursor) -> None: + self._cursor = cursor + + def cursor(self) -> FakeMetadataCursor: + return self._cursor + + +class FakeConnectionIdCursor: + def __init__(self, row: tuple[int] | None) -> None: + self.row = row + + def fetchone(self) -> tuple[int] | None: + return self.row + + +class FakeSelectableConnection: + def __init__(self) -> None: + self.selected_databases: list[str] = [] + + def select_db(self, db: str) -> None: + self.selected_databases.append(db) + + +class FakeSSLContext: + def __init__(self) -> None: + self.check_hostname = True + self.verify_mode = None + self.minimum_version = None + self.maximum_version = None + self.loaded_cert_chain: tuple[str, str | None] | None = None + self.cipher_string: str | None = None + + def load_cert_chain(self, certfile: str, keyfile: str | None = None) -> None: + self.loaded_cert_chain = (certfile, keyfile) + + def set_ciphers(self, cipher_string: str) -> None: + self.cipher_string = cipher_string + + +def make_executor_for_connect_tests() -> SQLExecute: + executor = SQLExecute.__new__(SQLExecute) + executor.dbname = 'stored_db' + executor.user = 'stored_user' + executor.password = 'stored_password' + executor.host = 'stored_host' + executor.port = 3306 + executor.socket = '/tmp/mysql.sock' + executor.character_set = 'utf8mb4' + executor.local_infile = True + executor.ssl = {'ca': '/stored/ca.pem'} + executor.server_info = None + executor.connection_id = None + executor.ssh_user = 'stored_ssh_user' + executor.ssh_host = None + executor.ssh_port = 22 + executor.ssh_password = 'stored_ssh_password' + executor.ssh_key_filename = '/stored/key.pem' + executor.init_command = 'select 1' + executor.unbuffered = False + executor.sandbox_mode = False + executor.conn = None + return executor + + +def make_executor_for_run_tests(conn: object | None = None) -> SQLExecute: + executor = SQLExecute.__new__(SQLExecute) + executor.conn = conn + return executor + + +def test_connect_updates_connection_state_and_merges_overrides(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + previous_conn = DummyConnection( + server_version='5.7.0', + close_error=pymysql.err.Error(), + ) + executor.conn = previous_conn + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + connect_kwargs = {} + reset_calls = [] + ssl_context = object() + ssl_params = {'ca': '/override/ca.pem'} + + def fake_connect(**kwargs): + connect_kwargs.update(kwargs) + return new_conn + + def fake_create_ssl_ctx(self, sslp): + assert self is executor + assert sslp == ssl_params + return ssl_context + + def fake_reset_connection_id(self) -> None: + assert self is executor + reset_calls.append(True) + self.connection_id = 42 + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, '_create_ssl_ctx', fake_create_ssl_ctx) + monkeypatch.setattr(SQLExecute, 'reset_connection_id', fake_reset_connection_id) + + executor.connect( + database='override_db', + user='override_user', + password='override_password', + host='override_host', + port=3307, + character_set='latin1', + local_infile=False, + ssl=ssl_params, + init_command='select 1; select 2', + unbuffered=True, + ) + + assert connect_kwargs['database'] == 'override_db' + assert connect_kwargs['user'] == 'override_user' + assert connect_kwargs['password'] == 'override_password' + assert connect_kwargs['host'] == 'override_host' + assert connect_kwargs['port'] == 3307 + assert connect_kwargs['unix_socket'] == '/tmp/mysql.sock' + assert connect_kwargs['charset'] == 'latin1' + assert connect_kwargs['local_infile'] is False + assert connect_kwargs['ssl'] is ssl_context + assert connect_kwargs['defer_connect'] is False + assert connect_kwargs['init_command'] == 'select 1; select 2' + assert connect_kwargs['cursorclass'] is sqlexecute.pymysql.cursors.SSCursor + assert connect_kwargs['client_flag'] & sqlexecute.pymysql.constants.CLIENT.INTERACTIVE + assert connect_kwargs['client_flag'] & sqlexecute.pymysql.constants.CLIENT.MULTI_STATEMENTS + assert connect_kwargs['program_name'] == 'mycli' + assert previous_conn.close_calls == 1 + assert executor.conn is new_conn + assert executor.dbname == 'override_db' + assert executor.user == 'override_user' + assert executor.password == 'override_password' + assert executor.host == 'override_host' + assert executor.port == 3307 + assert executor.socket == '/tmp/mysql.sock' + assert executor.character_set == 'latin1' + assert executor.ssl == ssl_params + assert executor.init_command == 'select 1; select 2' + assert executor.unbuffered is True + assert executor.connection_id == 42 + assert reset_calls == [True] + assert executor.server_info is not None + assert executor.server_info.version_str == '8.0.36' + assert executor.server_info.version == 80036 + + +def test_connect_sets_expired_password_flag(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + connect_kwargs = {} + + def fake_connect(**kwargs): + connect_kwargs.update(kwargs) + return new_conn + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, 'reset_connection_id', lambda self: None) + + executor.connect() + + assert connect_kwargs['client_flag'] & sqlexecute.pymysql.constants.CLIENT.HANDLE_EXPIRED_PASSWORDS + assert executor.sandbox_mode is False + + +def test_connect_falls_back_to_sandbox_on_1820(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + call_count = 0 + sandbox_calls = [] + + def fake_connect(**kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise pymysql.OperationalError(1820, 'must change password') + return new_conn + + def fake_connect_sandbox(self, conn): + sandbox_calls.append(conn) + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, '_connect_sandbox', fake_connect_sandbox) + + executor.connect() + + assert call_count == 2 + assert len(sandbox_calls) == 1 + assert executor.sandbox_mode is True + assert executor.server_info is None + assert executor.connection_id is None + + +def test_connect_reraises_non_sandbox_operational_error(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + + def fake_connect(**_kwargs): + raise pymysql.OperationalError(1045, 'access denied') + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + + with pytest.raises(pymysql.OperationalError) as exc_info: + executor.connect() + + assert exc_info.value.args == (1045, 'access denied') + + +def test_connect_uses_ssh_tunnel_when_ssh_host_is_set(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + connect_kwargs = {} + tunnel_args = {} + tunnel_started = [] + + class FakeTunnel: + def __init__( + self, + ssh_address_or_host, + ssh_username=None, + ssh_pkey=None, + ssh_password=None, + remote_bind_address=None, + ) -> None: + tunnel_args['ssh_address_or_host'] = ssh_address_or_host + tunnel_args['ssh_username'] = ssh_username + tunnel_args['ssh_pkey'] = ssh_pkey + tunnel_args['ssh_password'] = ssh_password + tunnel_args['remote_bind_address'] = remote_bind_address + self.local_bind_host = '127.0.0.1' + self.local_bind_port = 4406 + + def start(self) -> None: + tunnel_started.append(True) + + def fake_connect(**kwargs): + connect_kwargs.update(kwargs) + return new_conn + + def fake_reset_connection_id(self) -> None: + self.connection_id = 7 + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', fake_connect) + monkeypatch.setattr(SQLExecute, 'reset_connection_id', fake_reset_connection_id) + monkeypatch.setattr( + sqlexecute, + 'sshtunnel', + SimpleNamespace(SSHTunnelForwarder=FakeTunnel), + raising=False, + ) + + executor.connect( + host='db.internal', + port=3308, + ssh_host='bastion.internal', + ssh_port=2222, + ssh_user='alice', + ssh_password='secret', + ssh_key_filename='/tmp/id_rsa', + ) + + assert connect_kwargs['host'] == 'db.internal' + assert connect_kwargs['port'] == 3308 + assert connect_kwargs['defer_connect'] is True + assert connect_kwargs['init_command'] == 'select 1' + assert tunnel_args['ssh_address_or_host'] == ('bastion.internal', 2222) + assert tunnel_args['ssh_username'] == 'alice' + assert tunnel_args['ssh_pkey'] == '/tmp/id_rsa' + assert tunnel_args['ssh_password'] == 'secret' + assert tunnel_args['remote_bind_address'] == ('db.internal', 3308) + assert tunnel_started == [True] + assert new_conn.host == '127.0.0.1' + assert new_conn.port == 4406 + assert new_conn.connect_calls == 1 + assert executor.conn is new_conn + assert executor.host == 'db.internal' + assert executor.port == 3308 + assert executor.connection_id == 7 + + +def test_connect_reraises_ssh_tunnel_errors(monkeypatch) -> None: + executor = make_executor_for_connect_tests() + executor.ssl = None + new_conn = DummyConnection(server_version='8.0.36-0ubuntu0.22.04.1') + + class FakeTunnel: + def __init__(self, *args, **kwargs) -> None: + self.local_bind_host = '127.0.0.1' + self.local_bind_port = 4406 + + def start(self) -> None: + raise RuntimeError('tunnel failed') + + monkeypatch.setattr(sqlexecute.pymysql, 'connect', lambda **_kwargs: new_conn) + monkeypatch.setattr( + sqlexecute, + 'sshtunnel', + SimpleNamespace(SSHTunnelForwarder=FakeTunnel), + raising=False, + ) + + with pytest.raises(RuntimeError, match='tunnel failed'): + executor.connect(ssh_host='bastion.internal') + + +def test_connect_sandbox_temporarily_disables_set_character_set() -> None: + original_calls = [] + connect_observed_stub = [] + + class FakeSandboxConnection: + def set_character_set(self, *args, **kwargs) -> None: + original_calls.append((args, kwargs)) + + def connect(self) -> None: + self.set_character_set('utf8mb4') + connect_observed_stub.append(original_calls == []) + + conn = FakeSandboxConnection() + original_set_character_set = conn.set_character_set + + SQLExecute._connect_sandbox(conn) + + assert connect_observed_stub == [True] + assert conn.set_character_set == original_set_character_set + conn.set_character_set('latin1') + assert original_calls == [(('latin1',), {})] + + +def test_run_returns_empty_result_for_blank_statement(monkeypatch) -> None: + split_inputs: list[str] = [] + + def fake_split_queries(statement: str): + split_inputs.append(statement) + return iter(()) + + monkeypatch.setattr(sqlexecute.iocommands, 'split_queries', fake_split_queries) + + executor = make_executor_for_run_tests() + + assert list(executor.run(' \n\t ')) == [SQLResult()] + assert split_inputs == [''] + + +def test_run_does_not_split_favorite_query(monkeypatch) -> None: + favorite_results = [SQLResult(status='Saved.')] + favorite_sql = '\\fs test-name select 1; select 2' + cursor = FakeQueryCursor() + execute_calls: list[str] = [] + + def fake_execute(cur: FakeQueryCursor, sql: str) -> list[SQLResult]: + assert cur is cursor + execute_calls.append(sql) + return favorite_results + + def fail_split_queries(_statement: str): + raise AssertionError('split_queries() should not be called for favorite queries') + + monkeypatch.setattr(sqlexecute, 'Connection', FakeQueryConnection) + monkeypatch.setattr(sqlexecute, 'execute', fake_execute) + monkeypatch.setattr(sqlexecute.iocommands, 'split_queries', fail_split_queries) + + executor = make_executor_for_run_tests(FakeQueryConnection([cursor])) + + assert list(executor.run(favorite_sql)) == favorite_results + assert execute_calls == [favorite_sql] + assert cursor.executed == [] + + +def test_run_uses_special_command_results_without_regular_execution(monkeypatch) -> None: + cursor = FakeQueryCursor() + special_results = [SQLResult(status='special command')] + + def fake_execute(cur: FakeQueryCursor, sql: str) -> list[SQLResult]: + assert cur is cursor + assert sql == '\\dt' + return special_results + + def fail_get_result(_self: SQLExecute, _cursor: object) -> SQLResult: + raise AssertionError('get_result() should not be called for handled special commands') + + monkeypatch.setattr(sqlexecute, 'Connection', FakeQueryConnection) + monkeypatch.setattr(sqlexecute, 'execute', fake_execute) + monkeypatch.setattr(sqlexecute.iocommands, 'split_queries', lambda statement: iter([statement])) + monkeypatch.setattr(SQLExecute, 'get_result', fail_get_result) + + executor = make_executor_for_run_tests(FakeQueryConnection([cursor])) + + assert list(executor.run('\\dt')) == special_results + assert cursor.executed == [] + + +def test_run_falls_back_to_regular_sql_and_handles_output_flags(monkeypatch) -> None: + cursors = [FakeQueryCursor(), FakeQueryCursor()] + expanded_values: list[bool] = [] + forced_horizontal_values: list[bool] = [] + get_result_calls: list[list[str]] = [] + + def fake_execute(_cur: FakeQueryCursor, _sql: str) -> list[SQLResult]: + raise sqlexecute.CommandNotFound('not a special command') + + def fake_get_result(_self: SQLExecute, cursor: FakeQueryCursor) -> SQLResult: + get_result_calls.append(list(cursor.executed)) + return SQLResult(status=f'ran {cursor.executed[-1]}') + + monkeypatch.setattr(sqlexecute, 'Connection', FakeQueryConnection) + monkeypatch.setattr(sqlexecute, 'execute', fake_execute) + monkeypatch.setattr( + sqlexecute.iocommands, + 'split_queries', + lambda _statement: iter(['select 1\\G', 'select 2\\g']), + ) + monkeypatch.setattr( + sqlexecute.iocommands, + 'set_expanded_output', + lambda value: expanded_values.append(value), + ) + monkeypatch.setattr( + sqlexecute.iocommands, + 'set_forced_horizontal_output', + lambda value: forced_horizontal_values.append(value), + ) + monkeypatch.setattr(SQLExecute, 'get_result', fake_get_result) + + executor = make_executor_for_run_tests(FakeQueryConnection(cursors)) + + results = list(executor.run('select 1; select 2')) + + assert [result.status for result in results] == ['ran select 1', 'ran select 2'] + assert expanded_values == [True, False] + assert forced_horizontal_values == [True] + assert [cursor.executed for cursor in cursors] == [['select 1'], ['select 2']] + assert get_result_calls == [['select 1'], ['select 2']] + + +def test_run_yields_each_non_empty_result_set_until_nextset_is_false(monkeypatch) -> None: + cursor = FakeQueryCursor( + nextset_steps=[ + (True, 1, [('column',)]), + (False, 1, [('column',)]), + ] + ) + get_result_calls: list[int] = [] + + def fake_execute(_cur: FakeQueryCursor, _sql: str) -> list[SQLResult]: + raise sqlexecute.CommandNotFound('not a special command') + + def fake_get_result(_self: SQLExecute, _cursor: FakeQueryCursor) -> SQLResult: + get_result_calls.append(len(get_result_calls) + 1) + return SQLResult(status=f'result {len(get_result_calls)}') + + monkeypatch.setattr(sqlexecute, 'Connection', FakeQueryConnection) + monkeypatch.setattr(sqlexecute, 'execute', fake_execute) + monkeypatch.setattr(sqlexecute.iocommands, 'split_queries', lambda statement: iter([statement])) + monkeypatch.setattr(SQLExecute, 'get_result', fake_get_result) + + executor = make_executor_for_run_tests(FakeQueryConnection([cursor])) + + results = list(executor.run('call demo()')) + + assert [result.status for result in results] == ['result 1', 'result 2'] + assert cursor.executed == ['call demo()'] + assert get_result_calls == [1, 2] + + +def test_run_skips_trailing_empty_result_set_from_nextset(monkeypatch) -> None: + cursor = FakeQueryCursor(nextset_steps=[(True, 0, None)]) + get_result_calls: list[int] = [] + + def fake_execute(_cur: FakeQueryCursor, _sql: str) -> list[SQLResult]: + raise sqlexecute.CommandNotFound('not a special command') + + def fake_get_result(_self: SQLExecute, _cursor: FakeQueryCursor) -> SQLResult: + get_result_calls.append(1) + return SQLResult(status='result 1') + + monkeypatch.setattr(sqlexecute, 'Connection', FakeQueryConnection) + monkeypatch.setattr(sqlexecute, 'execute', fake_execute) + monkeypatch.setattr(sqlexecute.iocommands, 'split_queries', lambda statement: iter([statement])) + monkeypatch.setattr(SQLExecute, 'get_result', fake_get_result) + + executor = make_executor_for_run_tests(FakeQueryConnection([cursor])) + + results = list(executor.run('call demo()')) + + assert [result.status for result in results] == ['result 1'] + assert cursor.executed == ['call demo()'] + assert get_result_calls == [1] + + +def test_get_result_returns_header_and_row_status_for_result_sets() -> None: + cursor = FakeQueryCursor() + cursor.rowcount = 2 + cursor.description = [('name',), ('age',)] + cursor.warning_count = 0 + + executor = make_executor_for_run_tests() + + result = executor.get_result(cursor) + + assert result.preamble is None + assert result.header == ['name', 'age'] + assert result.rows is cursor + assert result.postamble is None + assert result.status_plain == '2 rows in set' + + +def test_get_result_returns_query_ok_status_when_no_result_set() -> None: + cursor = FakeQueryCursor() + cursor.rowcount = 1 + cursor.description = None + cursor.warning_count = 0 + + executor = make_executor_for_run_tests() + + result = executor.get_result(cursor) + + assert result.header is None + assert result.rows is cursor + assert result.status_plain == 'Query OK, 1 row affected' + + +def test_get_result_appends_warning_count_to_status() -> None: + cursor = FakeQueryCursor() + cursor.rowcount = 3 + cursor.description = [('name',)] + cursor.warning_count = 2 + + executor = make_executor_for_run_tests() + + result = executor.get_result(cursor) + + assert result.header == ['name'] + assert result.rows is cursor + assert result.status_plain == '3 rows in set, 2 warnings' + + +def test_tables_executes_show_tables_query_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('users',), ('orders',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.tables()) + + assert result == [('users',), ('orders',)] + assert cursor.executed == [(SQLExecute.tables_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_tables_returns_empty_generator_when_no_tables_exist(monkeypatch) -> None: + cursor = FakeMetadataCursor([]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.tables()) + + assert result == [] + assert cursor.executed == [(SQLExecute.tables_query, None)] + + +def test_table_columns_executes_query_with_dbname_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('users', 'id'), ('users', 'email'), ('orders', 'id')]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.table_columns()) + + assert result == [('users', 'id'), ('users', 'email'), ('orders', 'id')] + assert cursor.executed == [(SQLExecute.table_columns_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + + +def test_table_columns_returns_empty_generator_when_schema_has_no_tables(monkeypatch) -> None: + cursor = FakeMetadataCursor([]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'empty_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.table_columns()) + + assert result == [] + assert cursor.executed == [(SQLExecute.table_columns_query, ('empty_db',))] + + +def test_enum_values_executes_query_and_skips_non_enum_columns(monkeypatch) -> None: + cursor = FakeMetadataCursor([ + ('orders', 'status', "enum('new','paid')"), + ('orders', 'notes', 'varchar(255)'), + ]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.enum_values()) + + assert result == [('orders', 'status', ['new', 'paid'])] + assert cursor.executed == [(SQLExecute.enum_values_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + + +def test_enum_values_returns_empty_generator_when_no_enum_values_are_found(monkeypatch) -> None: + cursor = FakeMetadataCursor([('orders', 'notes', 'varchar(255)')]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'empty_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.enum_values()) + + assert result == [] + assert cursor.executed == [(SQLExecute.enum_values_query, ('empty_db',))] + + +def test_foreign_keys_executes_query_with_dbname_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([ + ('orders', 'customer_id', 'customers', 'id'), + ('order_items', 'order_id', 'orders', 'id'), + ]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.foreign_keys()) + + assert result == [ + ('orders', 'customer_id', 'customers', 'id'), + ('order_items', 'order_id', 'orders', 'id'), + ] + assert cursor.executed == [(SQLExecute.foreign_keys_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + + +def test_foreign_keys_returns_empty_generator_and_logs_execute_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=RuntimeError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.foreign_keys()) + + assert result == [] + assert cursor.executed == [(SQLExecute.foreign_keys_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + assert "No foreign key completions due to RuntimeError('boom')" in caplog.text + + +def test_databases_executes_show_databases_and_flattens_names(monkeypatch) -> None: + cursor = FakeMetadataCursor([('mysql',), ('information_schema',), ('app_db',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = executor.databases() + + assert result == ['mysql', 'information_schema', 'app_db'] + assert cursor.executed == [(SQLExecute.databases_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_databases_returns_empty_list_when_no_databases_are_found(monkeypatch) -> None: + cursor = FakeMetadataCursor([]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = executor.databases() + + assert result == [] + assert cursor.executed == [(SQLExecute.databases_query, None)] + + +def test_functions_executes_query_with_dbname_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('calculate_total',), ('format_order',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.functions()) + + assert result == [('calculate_total',), ('format_order',)] + assert cursor.executed == [(SQLExecute.functions_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + + +def test_functions_returns_empty_generator_when_schema_has_no_functions(monkeypatch) -> None: + cursor = FakeMetadataCursor([]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'empty_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.functions()) + + assert result == [] + assert cursor.executed == [(SQLExecute.functions_query, ('empty_db',))] + + +def test_procedures_executes_query_with_dbname_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('refresh_orders',), ('archive_orders',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.procedures()) + + assert result == [('refresh_orders',), ('archive_orders',)] + assert cursor.executed == [(SQLExecute.procedures_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + + +def test_procedures_yields_empty_tuple_and_logs_database_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=pymysql.DatabaseError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + executor.dbname = 'app_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.procedures()) + + assert result == [()] + assert cursor.executed == [(SQLExecute.procedures_query, ('app_db',))] + assert cursor.entered is True + assert cursor.exited is True + assert "No procedure completions due to DatabaseError('boom')" in caplog.text + + +def test_character_sets_executes_query_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('utf8mb4',), ('latin1',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.character_sets()) + + assert result == [('utf8mb4',), ('latin1',)] + assert cursor.executed == [(SQLExecute.character_sets_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_character_sets_yields_empty_tuple_and_logs_database_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=pymysql.DatabaseError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.character_sets()) + + assert result == [()] + assert cursor.executed == [(SQLExecute.character_sets_query, None)] + assert cursor.entered is True + assert cursor.exited is True + assert "No character_set completions due to DatabaseError('boom')" in caplog.text + + +def test_collations_executes_query_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([('utf8mb4_general_ci',), ('latin1_swedish_ci',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.collations()) + + assert result == [('utf8mb4_general_ci',), ('latin1_swedish_ci',)] + assert cursor.executed == [(SQLExecute.collations_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_collations_yields_empty_tuple_and_logs_database_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=pymysql.DatabaseError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.collations()) + + assert result == [()] + assert cursor.executed == [(SQLExecute.collations_query, None)] + assert cursor.entered is True + assert cursor.exited is True + assert "No collations completions due to DatabaseError('boom')" in caplog.text + + +def test_show_candidates_executes_query_and_strips_show_prefix(monkeypatch) -> None: + cursor = FakeMetadataCursor([('SHOW DATABASES',), ('SHOW FULL TABLES',)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.show_candidates()) + + assert result == [('DATABASES',), ('FULL TABLES',)] + assert cursor.executed == [(SQLExecute.show_candidates_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_show_candidates_yields_empty_tuple_and_logs_database_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=pymysql.DatabaseError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.show_candidates()) + + assert result == [()] + assert cursor.executed == [(SQLExecute.show_candidates_query, None)] + assert cursor.entered is True + assert cursor.exited is True + assert "No show completions due to DatabaseError('boom')" in caplog.text + + +def test_users_executes_query_and_yields_rows(monkeypatch) -> None: + cursor = FakeMetadataCursor([("'alice'@'localhost'",), ("'bob'@'%'",)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = list(executor.users()) + + assert result == [("'alice'@'localhost'",), ("'bob'@'%'",)] + assert cursor.executed == [(SQLExecute.users_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_users_yields_empty_tuple_and_logs_database_errors(monkeypatch, caplog) -> None: + cursor = FakeMetadataCursor([], execute_error=pymysql.DatabaseError('boom')) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = list(executor.users()) + + assert result == [()] + assert cursor.executed == [(SQLExecute.users_query, None)] + assert cursor.entered is True + assert cursor.exited is True + assert "No user completions due to DatabaseError('boom')" in caplog.text + + +def test_now_returns_database_timestamp_from_first_row(monkeypatch) -> None: + timestamp = sqlexecute.datetime.datetime(2024, 1, 2, 3, 4, 5) + cursor = FakeMetadataCursor([(timestamp,)]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + + result = executor.now() + + assert result == timestamp + assert cursor.executed == [(SQLExecute.now_query, None)] + assert cursor.entered is True + assert cursor.exited is True + + +def test_now_falls_back_to_local_datetime_when_query_returns_no_rows(monkeypatch) -> None: + fallback = sqlexecute.datetime.datetime(2024, 6, 7, 8, 9, 10) + cursor = FakeMetadataCursor([]) + executor = make_executor_for_run_tests(FakeMetadataConnection(cursor)) + + class FakeDateTime: + @classmethod + def now(cls) -> sqlexecute.datetime.datetime: + return fallback + + monkeypatch.setattr(sqlexecute, 'Connection', FakeMetadataConnection) + monkeypatch.setattr(sqlexecute.datetime, 'datetime', FakeDateTime) + + result = executor.now() + + assert result == fallback + assert cursor.executed == [(SQLExecute.now_query, None)] + + +def test_get_connection_id_returns_cached_value_without_reset(monkeypatch) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = 123 + + def fail_reset_connection_id(self) -> None: + raise AssertionError('reset_connection_id() should not be called') + + monkeypatch.setattr(SQLExecute, 'reset_connection_id', fail_reset_connection_id) + + assert executor.get_connection_id() == 123 + + +def test_get_connection_id_resets_when_connection_id_is_missing(monkeypatch) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = None + reset_calls: list[bool] = [] + + def fake_reset_connection_id(self) -> None: + reset_calls.append(True) + self.connection_id = 456 + + monkeypatch.setattr(SQLExecute, 'reset_connection_id', fake_reset_connection_id) + + assert executor.get_connection_id() == 456 + assert reset_calls == [True] + + +def test_reset_connection_id_sets_connection_id_from_query_result(monkeypatch) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = None + run_calls: list[str] = [] + + def fake_run(sql: str): + run_calls.append(sql) + return [SimpleNamespace(rows=FakeConnectionIdCursor((789,)))] + + monkeypatch.setattr(sqlexecute, 'Cursor', FakeConnectionIdCursor) + monkeypatch.setattr(executor, 'run', fake_run) + + executor.reset_connection_id() + + assert executor.connection_id == 789 + assert run_calls == ['select connection_id()'] + + +def test_reset_connection_id_sets_minus_one_when_query_returns_no_row(monkeypatch) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = None + + monkeypatch.setattr(sqlexecute, 'Cursor', FakeConnectionIdCursor) + monkeypatch.setattr( + executor, + 'run', + lambda _sql: [SimpleNamespace(rows=FakeConnectionIdCursor(None))], + ) + + executor.reset_connection_id() + + assert executor.connection_id == -1 + + +def test_reset_connection_id_leaves_connection_id_unset_when_query_returns_no_results(monkeypatch) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = None + + monkeypatch.setattr(executor, 'run', lambda _sql: iter(())) + + executor.reset_connection_id() + + assert executor.connection_id is None + + +def test_reset_connection_id_sets_minus_one_and_logs_errors_for_invalid_results(monkeypatch, caplog) -> None: + executor = make_executor_for_run_tests() + executor.connection_id = None + + monkeypatch.setattr(sqlexecute, 'Cursor', FakeConnectionIdCursor) + monkeypatch.setattr(executor, 'run', lambda _sql: [SimpleNamespace(rows=object())]) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + executor.reset_connection_id() + + assert executor.connection_id == -1 + assert 'Failed to get connection id:' in caplog.text + + +def test_change_db_selects_database_and_updates_dbname(monkeypatch) -> None: + conn = FakeSelectableConnection() + executor = make_executor_for_run_tests(conn) + executor.dbname = 'old_db' + monkeypatch.setattr(sqlexecute, 'Connection', FakeSelectableConnection) + + executor.change_db('new_db') + + assert conn.selected_databases == ['new_db'] + assert executor.dbname == 'new_db' + + +def test_create_ssl_ctx_without_ca_disables_hostname_check_and_verification(monkeypatch) -> None: + executor = make_executor_for_run_tests() + ctx = FakeSSLContext() + create_default_context_calls: list[tuple[str | None, str | None]] = [] + + def fake_create_default_context(cafile: str | None = None, capath: str | None = None) -> FakeSSLContext: + create_default_context_calls.append((cafile, capath)) + return ctx + + monkeypatch.setattr(sqlexecute.ssl, 'create_default_context', fake_create_default_context) + + result = executor._create_ssl_ctx({}) + + assert result is ctx + assert create_default_context_calls == [(None, None)] + assert ctx.check_hostname is False + assert ctx.verify_mode == sqlexecute.ssl.CERT_NONE + assert ctx.minimum_version == sqlexecute.ssl.TLSVersion.TLSv1_2 + assert ctx.maximum_version is None + assert ctx.loaded_cert_chain is None + assert ctx.cipher_string is None + + +def test_create_ssl_ctx_applies_cert_cipher_and_tls_version(monkeypatch) -> None: + executor = make_executor_for_run_tests() + ctx = FakeSSLContext() + create_default_context_calls: list[tuple[str | None, str | None]] = [] + + def fake_create_default_context(cafile: str | None = None, capath: str | None = None) -> FakeSSLContext: + create_default_context_calls.append((cafile, capath)) + return ctx + + monkeypatch.setattr( + sqlexecute.ssl, + 'create_default_context', + fake_create_default_context, + ) + + result = executor._create_ssl_ctx({ + 'ca': '/tmp/ca.pem', + 'check_hostname': False, + 'cert': '/tmp/client-cert.pem', + 'key': '/tmp/client-key.pem', + 'cipher': 'ECDHE-RSA-AES256-GCM-SHA384', + 'tls_version': 'TLSv1.3', + }) + + assert result is ctx + assert create_default_context_calls == [('/tmp/ca.pem', None)] + assert ctx.check_hostname is False + assert ctx.verify_mode == sqlexecute.ssl.CERT_REQUIRED + assert ctx.loaded_cert_chain == ('/tmp/client-cert.pem', '/tmp/client-key.pem') + assert ctx.cipher_string == 'ECDHE-RSA-AES256-GCM-SHA384' + assert ctx.minimum_version == sqlexecute.ssl.TLSVersion.TLSv1_3 + assert ctx.maximum_version == sqlexecute.ssl.TLSVersion.TLSv1_3 + + +@pytest.mark.parametrize( + ('tls_version', 'expected_version'), + ( + ('TLSv1', sqlexecute.ssl.TLSVersion.TLSv1), + ('TLSv1.1', sqlexecute.ssl.TLSVersion.TLSv1_1), + ('TLSv1.2', sqlexecute.ssl.TLSVersion.TLSv1_2), + ), +) +def test_create_ssl_ctx_supports_legacy_tls_version_overrides(monkeypatch, tls_version: str, expected_version) -> None: + executor = make_executor_for_run_tests() + ctx = FakeSSLContext() + + monkeypatch.setattr(sqlexecute.ssl, 'create_default_context', lambda **_kwargs: ctx) + + result = executor._create_ssl_ctx({'tls_version': tls_version}) + + assert result is ctx + assert ctx.minimum_version == expected_version + assert ctx.maximum_version == expected_version + + +def test_create_ssl_ctx_logs_invalid_tls_version_and_keeps_default_minimum(monkeypatch, caplog) -> None: + executor = make_executor_for_run_tests() + ctx = FakeSSLContext() + + monkeypatch.setattr(sqlexecute.ssl, 'create_default_context', lambda **_kwargs: ctx) + + with caplog.at_level('ERROR', logger='mycli.sqlexecute'): + result = executor._create_ssl_ctx({'tls_version': 'SSLv3'}) + + assert result is ctx + assert ctx.minimum_version == sqlexecute.ssl.TLSVersion.TLSv1_2 + assert ctx.maximum_version is None + assert 'Invalid tls version: SSLv3' in caplog.text + + +def test_close_calls_connection_close_when_present() -> None: + conn = DummyConnection(server_version='8.0.0') + executor = make_executor_for_run_tests(conn) + + executor.close() + + assert conn.close_calls == 1 + + +def test_close_swallows_pymysql_errors() -> None: + conn = DummyConnection(server_version='8.0.0', close_error=pymysql.err.Error()) + executor = make_executor_for_run_tests(conn) + + executor.close() + + assert conn.close_calls == 1 + + +def test_close_does_nothing_when_connection_is_none() -> None: + executor = make_executor_for_run_tests() + + executor.close() diff --git a/test/pytests/test_sqlresult.py b/test/pytests/test_sqlresult.py new file mode 100644 index 00000000..9c19293a --- /dev/null +++ b/test/pytests/test_sqlresult.py @@ -0,0 +1,29 @@ +from prompt_toolkit.formatted_text import FormattedText + +from mycli.packages.sqlresult import SQLResult + + +def test_sqlresult_str_includes_all_fields() -> None: + result = SQLResult( + preamble='before', + header=['id'], + rows=[(1,)], + postamble='after', + status='ok', + command={'name': 'watch', 'seconds': 1.0}, + ) + + assert 'before' in str(result) + assert "['id']" in str(result) + assert '[(1,)]' in str(result) + assert 'after' in str(result) + assert 'ok' in str(result) + assert "{'name': 'watch', 'seconds': 1.0}" in str(result) + + +def test_sqlresult_status_plain_handles_none_and_formatted_text() -> None: + empty = SQLResult() + formatted = SQLResult(status=FormattedText([('', '1 row in set'), ('', ', '), ('class:warn', '1 warning')])) + + assert empty.status_plain is None + assert formatted.status_plain == '1 row in set, 1 warning' diff --git a/test/pytests/test_ssh_utils.py b/test/pytests/test_ssh_utils.py new file mode 100644 index 00000000..1f26ce0b --- /dev/null +++ b/test/pytests/test_ssh_utils.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import builtins +import importlib +from pathlib import Path +import sys +from typing import TextIO + +import pytest + +from mycli.packages import paramiko_stub, ssh_utils + + +class FakeSSHConfig: + def __init__(self, parse_error: Exception | None = None) -> None: + self.parse_error = parse_error + self.parsed_text: str | None = None + + def parse(self, handle: TextIO) -> None: + if self.parse_error is not None: + raise self.parse_error + self.parsed_text = handle.read() + + +def test_read_ssh_config_parses_and_returns_config(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path = tmp_path / 'ssh_config' + config_path.write_text('Host demo\n HostName example.com\n', encoding='utf-8') + fake_ssh_config = FakeSSHConfig() + + monkeypatch.setattr(ssh_utils.paramiko.config, 'SSHConfig', lambda: fake_ssh_config) + + result = ssh_utils.read_ssh_config(str(config_path)) + + assert result is fake_ssh_config + assert fake_ssh_config.parsed_text == 'Host demo\n HostName example.com\n' + + +def test_read_ssh_config_reports_missing_file_and_exits(monkeypatch: pytest.MonkeyPatch) -> None: + secho_calls: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr( + ssh_utils.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + with pytest.raises(SystemExit) as excinfo: + ssh_utils.read_ssh_config('/definitely/missing/ssh_config') + + assert excinfo.value.code == 1 + assert secho_calls == [("[Errno 2] No such file or directory: '/definitely/missing/ssh_config'", True, 'red')] + + +def test_read_ssh_config_reports_parse_errors_and_exits(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + config_path = tmp_path / 'ssh_config' + config_path.write_text('Host broken\n', encoding='utf-8') + fake_ssh_config = FakeSSHConfig(parse_error=RuntimeError('bad config')) + secho_calls: list[tuple[str, bool, str]] = [] + + monkeypatch.setattr(ssh_utils.paramiko.config, 'SSHConfig', lambda: fake_ssh_config) + monkeypatch.setattr( + ssh_utils.click, + 'secho', + lambda message, err, fg: secho_calls.append((message, err, fg)), + ) + + with pytest.raises(SystemExit) as excinfo: + ssh_utils.read_ssh_config(str(config_path)) + + assert excinfo.value.code == 1 + assert secho_calls == [(f'Could not parse SSH configuration file {config_path}:\nbad config ', True, 'red')] + + +def test_ssh_utils_falls_back_to_paramiko_stub_when_paramiko_is_unavailable(monkeypatch: pytest.MonkeyPatch) -> None: + original_import = builtins.__import__ + + def fake_import(name: str, globals_=None, locals_=None, fromlist=(), level: int = 0): + if name == 'paramiko': + raise ImportError('paramiko not installed') + return original_import(name, globals_, locals_, fromlist, level) + + monkeypatch.delitem(sys.modules, 'paramiko', raising=False) + monkeypatch.setattr(builtins, '__import__', fake_import) + + reloaded = importlib.reload(ssh_utils) + + assert reloaded.paramiko is paramiko_stub.paramiko + + monkeypatch.undo() + importlib.reload(ssh_utils) diff --git a/test/pytests/test_string_utils.py b/test/pytests/test_string_utils.py new file mode 100644 index 00000000..338a797a --- /dev/null +++ b/test/pytests/test_string_utils.py @@ -0,0 +1,27 @@ +# type: ignore + +from mycli.packages.string_utils import sanitize_terminal_title + + +def test_sanitize_terminal_title_strips_ansi_sequences() -> None: + title = '\x1b[31mmycli\x1b[0m session' + + assert sanitize_terminal_title(title) == 'mycli session' + + +def test_sanitize_terminal_title_replaces_newlines_with_spaces() -> None: + title = 'schema\nquery\r\nprompt' + + assert sanitize_terminal_title(title) == 'schema query prompt' + + +def test_sanitize_terminal_title_removes_control_characters() -> None: + title = 'my\x00cl\ti\x1f title\x7f' + + assert sanitize_terminal_title(title) == 'mycli title' + + +def test_sanitize_terminal_title_preserves_printable_text() -> None: + title = 'db-01 / reporting' + + assert sanitize_terminal_title(title) == 'db-01 / reporting' diff --git a/test/pytests/test_tabular_output.py b/test/pytests/test_tabular_output.py new file mode 100644 index 00000000..f1f3d8c5 --- /dev/null +++ b/test/pytests/test_tabular_output.py @@ -0,0 +1,178 @@ +# type: ignore + +"""Test the sql output adapter.""" + +import os +from textwrap import dedent + +from cli_helpers.utils import strip_ansi +from pymysql.constants import FIELD_TYPE +import pytest + +from mycli.main import MyCli +from mycli.packages.sqlresult import SQLResult +from test.utils import HOST, PASSWORD, PORT, USER, dbtest + +default_config_file = os.path.join(os.path.dirname(__file__), "../myclirc") + + +@pytest.fixture +def mycli(): + cli = MyCli() + cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None) + yield cli + cli.sqlexecute.conn.close() + + +@dbtest +def test_sql_output(mycli): + """Test the sql output adapter.""" + header = ["letters", "number", "optional", "float", "binary"] + + class FakeCursor: + def __init__(self): + self.data = [("abc", 1, None, 10.0, b"\xaa"), ("d", 456, "1", 0.5, b"\xaa\xbb")] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + (None, FIELD_TYPE.BLOB), + ] + + def __iter__(self): + return self + + def __next__(self): + if self.data: + return self.data.pop(0) + else: + raise StopIteration() + + def description(self): + return self.description + + # Test sql-update output format + assert list(mycli.change_table_format("sql-update")) == [SQLResult(status="Changed table format to sql-update")] + mycli.main_formatter.query = "" + mycli.redirect_formatter.query = "" + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + actual = "\n".join(output) + assert actual == dedent("""\ + UPDATE `DUAL` SET + `number` = 1 + , `optional` = NULL + , `float` = 10.0e0 + , `binary` = 0xaa + WHERE `letters` = 'abc'; + UPDATE `DUAL` SET + `number` = 456 + , `optional` = '1' + , `float` = 0.5e0 + , `binary` = 0xaabb + WHERE `letters` = 'd';""") + # Test sql-update-2 output format + assert list(mycli.change_table_format("sql-update-2")) == [SQLResult(status="Changed table format to sql-update-2")] + mycli.main_formatter.query = "" + mycli.redirect_formatter.query = "" + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + assert "\n".join(output) == dedent("""\ + UPDATE `DUAL` SET + `optional` = NULL + , `float` = 10.0e0 + , `binary` = 0xaa + WHERE `letters` = 'abc' AND `number` = 1; + UPDATE `DUAL` SET + `optional` = '1' + , `float` = 0.5e0 + , `binary` = 0xaabb + WHERE `letters` = 'd' AND `number` = 456;""") + # Test sql-insert output format (without table name) + assert list(mycli.change_table_format("sql-insert")) == [SQLResult(status="Changed table format to sql-insert")] + mycli.main_formatter.query = "" + mycli.redirect_formatter.query = "" + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + assert "\n".join(output) == dedent("""\ + INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, 0xaa) + , ('d', 456, '1', 0.5e0, 0xaabb) + ;""") + # Test sql-insert output format (with table name) + assert list(mycli.change_table_format("sql-insert")) == [SQLResult(status="Changed table format to sql-insert")] + mycli.main_formatter.query = "SELECT * FROM `table`" + mycli.redirect_formatter.query = "SELECT * FROM `table`" + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + assert "\n".join(output) == dedent("""\ + INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, 0xaa) + , ('d', 456, '1', 0.5e0, 0xaabb) + ;""") + # Test sql-insert output format (with database + table name) + assert list(mycli.change_table_format("sql-insert")) == [SQLResult(status="Changed table format to sql-insert")] + mycli.main_formatter.query = "SELECT * FROM `database`.`table`" + mycli.redirect_formatter.query = "SELECT * FROM `database`.`table`" + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + assert "\n".join(output) == dedent("""\ + INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES + ('abc', 1, NULL, 10.0e0, 0xaa) + , ('d', 456, '1', 0.5e0, 0xaabb) + ;""") + # Test binary output format is a hex string + assert list(mycli.change_table_format("psql")) == [SQLResult(status="Changed table format to psql")] + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor())) + assert '0xaabb' in '\n'.join(output) + + +@dbtest +def test_postamble_output(mycli): + """Test the postamble output property.""" + header = ['letters', 'number', 'optional', 'float'] + + class FakeCursor: + def __init__(self): + self.data = [('abc', 1, None, 10.0)] + self.description = [ + (None, FIELD_TYPE.VARCHAR), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.LONG), + (None, FIELD_TYPE.FLOAT), + ] + + def __iter__(self): + return self + + def __next__(self): + if self.data: + return self.data.pop(0) + else: + raise StopIteration() + + def description(self): + return self.description + + postamble = 'postamble:\nfooter content' + mycli.change_table_format('ascii') + mycli.main_formatter.query = '' + output = mycli.format_sqlresult(SQLResult(header=header, rows=FakeCursor(), postamble=postamble)) + actual = "\n".join(output) + assert actual.endswith(postamble) + + +def test_tabulate_output_preserves_multiline_whitespace(monkeypatch, tmp_path): + monkeypatch.setenv("HOME", str(tmp_path)) + mycli = MyCli(myclirc=default_config_file) + mycli.helpers_style = None + mycli.helpers_warnings_style = None + + assert list(mycli.change_table_format("ascii")) == [SQLResult(status="Changed table format to ascii")] + + output = mycli.format_sqlresult(SQLResult(header=["text"], rows=[[" one\n two\nthree"]])) + + assert strip_ansi("\n".join(output)) == dedent("""\ + +------------+ + | text | + +------------+ + | one | + | two | + | three | + +------------+""") diff --git a/test/test.txt b/test/test.txt deleted file mode 100644 index 8d8b211e..00000000 --- a/test/test.txt +++ /dev/null @@ -1 +0,0 @@ -mycli rocks! diff --git a/test/test_clistyle.py b/test/test_clistyle.py deleted file mode 100644 index f82cdf0e..00000000 --- a/test/test_clistyle.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Test the mycli.clistyle module.""" -import pytest - -from pygments.style import Style -from pygments.token import Token - -from mycli.clistyle import style_factory - - -@pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory(): - """Test that a Pygments Style class is created.""" - header = 'bold underline #ansired' - cli_style = {'Token.Output.Header': header} - style = style_factory('default', cli_style) - - assert isinstance(style(), Style) - assert Token.Output.Header in style.styles - assert header == style.styles[Token.Output.Header] - - -@pytest.mark.skip(reason="incompatible with new prompt toolkit") -def test_style_factory_unknown_name(): - """Test that an unrecognized name will not throw an error.""" - style = style_factory('foobar', {}) - - assert isinstance(style(), Style) diff --git a/test/test_completion_engine.py b/test/test_completion_engine.py deleted file mode 100644 index 318b6328..00000000 --- a/test/test_completion_engine.py +++ /dev/null @@ -1,555 +0,0 @@ -from mycli.packages.completion_engine import suggest_type -import pytest - - -def sorted_dicts(dicts): - """input is a list of dicts.""" - return sorted(tuple(x.items()) for x in dicts) - - -def test_select_suggests_cols_with_visible_table_scope(): - suggestions = suggest_type('SELECT FROM tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_select_suggests_cols_with_qualified_table_scope(): - suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [('sch', 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE ', - 'SELECT * FROM tabl WHERE (', - 'SELECT * FROM tabl WHERE foo = ', - 'SELECT * FROM tabl WHERE bar OR ', - 'SELECT * FROM tabl WHERE foo = 1 AND ', - 'SELECT * FROM tabl WHERE (bar > 10 AND ', - 'SELECT * FROM tabl WHERE (bar AND (baz OR (qux AND (', - 'SELECT * FROM tabl WHERE 10 < ', - 'SELECT * FROM tabl WHERE foo BETWEEN ', - 'SELECT * FROM tabl WHERE foo BETWEEN foo AND ', -]) -def test_where_suggests_columns_functions(expression): - suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM tabl WHERE foo IN (', - 'SELECT * FROM tabl WHERE foo IN (bar, ', -]) -def test_where_in_suggests_columns(expression): - suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_where_equals_any_suggests_columns_or_keywords(): - text = 'SELECT * FROM tabl WHERE foo = ANY(' - suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}]) - - -def test_lparen_suggests_cols(): - suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] - - -def test_operand_inside_function_suggests_cols1(): - suggestion = suggest_type( - 'SELECT MAX(col1 + FROM tbl', 'SELECT MAX(col1 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] - - -def test_operand_inside_function_suggests_cols2(): - suggestion = suggest_type( - 'SELECT MAX(col1 + col2 + FROM tbl', 'SELECT MAX(col1 + col2 + ') - assert suggestion == [ - {'type': 'column', 'tables': [(None, 'tbl', None)]}] - - -def test_select_suggests_cols_and_funcs(): - suggestions = suggest_type('SELECT ', 'SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': []}, - {'type': 'column', 'tables': []}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM ', - 'INSERT INTO ', - 'COPY ', - 'UPDATE ', - 'DESCRIBE ', - 'DESC ', - 'EXPLAIN ', - 'SELECT * FROM foo JOIN ', -]) -def test_expression_suggests_tables_views_and_schemas(expression): - suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM sch.', - 'INSERT INTO sch.', - 'COPY sch.', - 'UPDATE sch.', - 'DESCRIBE sch.', - 'DESC sch.', - 'EXPLAIN sch.', - 'SELECT * FROM foo JOIN sch.', -]) -def test_expression_suggests_qualified_tables_views_and_schemas(expression): - suggestions = suggest_type(expression, expression) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}, - {'type': 'view', 'schema': 'sch'}]) - - -def test_truncate_suggests_tables_and_schemas(): - suggestions = suggest_type('TRUNCATE ', 'TRUNCATE ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'schema'}]) - - -def test_truncate_suggests_qualified_tables(): - suggestions = suggest_type('TRUNCATE sch.', 'TRUNCATE sch.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 'sch'}]) - - -def test_distinct_suggests_cols(): - suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ') - assert suggestions == [{'type': 'column', 'tables': []}] - - -def test_col_comma_suggests_cols(): - suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tbl']}, - {'type': 'column', 'tables': [(None, 'tbl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_table_comma_suggests_tables_and_schemas(): - suggestions = suggest_type('SELECT a, b FROM tbl1, ', - 'SELECT a, b FROM tbl1, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -def test_into_suggests_tables_and_schemas(): - suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ') - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -def test_insert_into_lparen_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] - - -def test_insert_into_lparen_partial_text_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] - - -def test_insert_into_lparen_comma_suggests_cols(): - suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,') - assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}] - - -def test_partially_typed_col_name_suggests_col_names(): - suggestions = suggest_type('SELECT * FROM tabl WHERE col_n', - 'SELECT * FROM tabl WHERE col_n') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['tabl']}, - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_dot_suggests_cols_of_a_table_or_schema_qualified_table(): - suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', None)]}, - {'type': 'table', 'schema': 'tabl'}, - {'type': 'view', 'schema': 'tabl'}, - {'type': 'function', 'schema': 'tabl'}]) - - -def test_dot_suggests_cols_of_an_alias(): - suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': 't1'}, - {'type': 'view', 'schema': 't1'}, - {'type': 'column', 'tables': [(None, 'tabl1', 't1')]}, - {'type': 'function', 'schema': 't1'}]) - - -def test_dot_col_comma_suggests_cols_or_schema_qualified_table(): - suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2', - 'SELECT t1.a, t2.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl2', 't2')]}, - {'type': 'table', 'schema': 't2'}, - {'type': 'view', 'schema': 't2'}, - {'type': 'function', 'schema': 't2'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (', - 'SELECT * FROM foo WHERE EXISTS (', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (', - 'SELECT 1 AS', -]) -def test_sub_select_suggests_keyword(expression): - suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (S', - 'SELECT * FROM foo WHERE EXISTS (S', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (S', -]) -def test_sub_select_partial_text_suggests_keyword(expression): - suggestion = suggest_type(expression, expression) - assert suggestion == [{'type': 'keyword'}] - - -def test_outer_table_reference_in_exists_subquery_suggests_columns(): - q = 'SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f.' - suggestions = suggest_type(q, q) - assert suggestions == [ - {'type': 'column', 'tables': [(None, 'foo', 'f')]}, - {'type': 'table', 'schema': 'f'}, - {'type': 'view', 'schema': 'f'}, - {'type': 'function', 'schema': 'f'}] - - -@pytest.mark.parametrize('expression', [ - 'SELECT * FROM (SELECT * FROM ', - 'SELECT * FROM foo WHERE EXISTS (SELECT * FROM ', - 'SELECT * FROM foo WHERE bar AND NOT EXISTS (SELECT * FROM ', -]) -def test_sub_select_table_name_completion(expression): - suggestion = suggest_type(expression, expression) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -def test_sub_select_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT FROM abc', - 'SELECT * FROM (SELECT ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['abc']}, - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -@pytest.mark.xfail -def test_sub_select_multiple_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc', - 'SELECT * FROM (SELECT a, ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', None)]}, - {'type': 'function', 'schema': []}]) - - -def test_sub_select_dot_col_name_completion(): - suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t', - 'SELECT * FROM (SELECT t.') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'tabl', 't')]}, - {'type': 'table', 'schema': 't'}, - {'type': 'view', 'schema': 't'}, - {'type': 'function', 'schema': 't'}]) - - -@pytest.mark.parametrize('join_type', ['', 'INNER', 'LEFT', 'RIGHT OUTER']) -@pytest.mark.parametrize('tbl_alias', ['', 'foo']) -def test_join_suggests_tables_and_schemas(tbl_alias, join_type): - text = 'SELECT * FROM abc {0} {1} JOIN '.format(tbl_alias, join_type) - suggestion = suggest_type(text, text) - assert sorted_dicts(suggestion) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.', -]) -def test_join_alias_dot_suggests_cols1(sql): - suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'abc', 'a')]}, - {'type': 'table', 'schema': 'a'}, - {'type': 'view', 'schema': 'a'}, - {'type': 'function', 'schema': 'a'}]) - - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.', -]) -def test_join_alias_dot_suggests_cols2(sql): - suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'def', 'd')]}, - {'type': 'table', 'schema': 'd'}, - {'type': 'view', 'schema': 'd'}, - {'type': 'function', 'schema': 'd'}]) - - -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id OR ', -]) -def test_on_suggests_aliases(sql): - suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] - - -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id AND ', -]) -def test_on_suggests_tables(sql): - suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] - - -@pytest.mark.parametrize('sql', [ - 'select a.x, b.y from abc a join bcd b on a.id = ', - 'select a.x, b.y from abc a join bcd b on a.id = b.id AND a.id2 = ', -]) -def test_on_suggests_aliases_right_side(sql): - suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}] - - -@pytest.mark.parametrize('sql', [ - 'select abc.x, bcd.y from abc join bcd on ', - 'select abc.x, bcd.y from abc join bcd on abc.id = bcd.id and ', -]) -def test_on_suggests_tables_right_side(sql): - suggestions = suggest_type(sql, sql) - assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}] - - -@pytest.mark.parametrize('col_list', ['', 'col1, ']) -def test_join_using_suggests_common_columns(col_list): - text = 'select * from abc inner join def using (' + col_list - assert suggest_type(text, text) == [ - {'type': 'column', - 'tables': [(None, 'abc', None), (None, 'def', None)], - 'drop_unique': True}] - -@pytest.mark.parametrize('sql', [ - 'SELECT * FROM abc a JOIN def d ON a.id = d.id JOIN ghi g ON g.', - 'SELECT * FROM abc a JOIN def d ON a.id = d.id AND a.id2 = d.id2 JOIN ghi g ON d.id = g.id AND g.', -]) -def test_two_join_alias_dot_suggests_cols1(sql): - suggestions = suggest_type(sql, sql) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'column', 'tables': [(None, 'ghi', 'g')]}, - {'type': 'table', 'schema': 'g'}, - {'type': 'view', 'schema': 'g'}, - {'type': 'function', 'schema': 'g'}]) - -def test_2_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - # Should work even if first statement is invalid - suggestions = suggest_type('select * from; select * from ', - 'select * from; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -def test_2_statements_1st_current(): - suggestions = suggest_type('select * from ; select * from b', - 'select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select from a; select * from b', - 'select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['a']}, - {'type': 'column', 'tables': [(None, 'a', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_3_statements_2nd_current(): - suggestions = suggest_type('select * from a; select * from ; select * from c', - 'select * from a; select * from ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - suggestions = suggest_type('select * from a; select from b; select * from c', - 'select * from a; select ') - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'alias', 'aliases': ['b']}, - {'type': 'column', 'tables': [(None, 'b', None)]}, - {'type': 'function', 'schema': []}, - {'type': 'keyword'}, - ]) - - -def test_create_db_with_template(): - suggestions = suggest_type('create database foo with template ', - 'create database foo with template ') - - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'database'}]) - - -@pytest.mark.parametrize('initial_text', ['', ' ', '\t \t']) -def test_specials_included_for_initial_completion(initial_text): - suggestions = suggest_type(initial_text, initial_text) - - assert sorted_dicts(suggestions) == \ - sorted_dicts([{'type': 'keyword'}, {'type': 'special'}]) - - -def test_specials_not_included_after_initial_token(): - suggestions = suggest_type('create table foo (dt d', - 'create table foo (dt d') - - assert sorted_dicts(suggestions) == sorted_dicts([{'type': 'keyword'}]) - - -def test_drop_schema_qualified_table_suggests_only_tables(): - text = 'DROP TABLE schema_name.table_name' - suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'table', 'schema': 'schema_name'}] - - -@pytest.mark.parametrize('text', [',', ' ,', 'sel ,']) -def test_handle_pre_completion_comma_gracefully(text): - suggestions = suggest_type(text, text) - - assert iter(suggestions) - - -def test_cross_join(): - text = 'select * from v1 cross join v2 JOIN v1.id, ' - suggestions = suggest_type(text, text) - assert sorted_dicts(suggestions) == sorted_dicts([ - {'type': 'table', 'schema': []}, - {'type': 'view', 'schema': []}, - {'type': 'schema'}]) - - -@pytest.mark.parametrize('expression', [ - 'SELECT 1 AS ', - 'SELECT 1 FROM tabl AS ', -]) -def test_after_as(expression): - suggestions = suggest_type(expression, expression) - assert set(suggestions) == set() - - -@pytest.mark.parametrize('expression', [ - '\\. ', - 'select 1; \\. ', - 'select 1;\\. ', - 'select 1 ; \\. ', - 'source ', - 'truncate table test; source ', - 'truncate table test ; source ', - 'truncate table test;source ', -]) -def test_source_is_file(expression): - suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'file_name'}] - - -@pytest.mark.parametrize("expression", [ - "\\f ", -]) -def test_favorite_name_suggestion(expression): - suggestions = suggest_type(expression, expression) - assert suggestions == [{'type': 'favoritequery'}] - - -def test_order_by(): - text = 'select * from foo order by ' - suggestions = suggest_type(text, text) - assert suggestions == [{'tables': [(None, 'foo', None)], 'type': 'column'}] - - -def test_quoted_where(): - text = "'where i=';" - suggestions = suggest_type(text, text) - assert suggestions == [{'type': 'keyword'}] diff --git a/test/test_completion_refresher.py b/test/test_completion_refresher.py deleted file mode 100644 index 31359cf3..00000000 --- a/test/test_completion_refresher.py +++ /dev/null @@ -1,88 +0,0 @@ -import time -import pytest -from unittest.mock import Mock, patch - - -@pytest.fixture -def refresher(): - from mycli.completion_refresher import CompletionRefresher - return CompletionRefresher() - - -def test_ctor(refresher): - """Refresher object should contain a few handlers. - - :param refresher: - :return: - - """ - assert len(refresher.refreshers) > 0 - actual_handlers = list(refresher.refreshers.keys()) - expected_handlers = ['databases', 'schemata', 'tables', 'users', 'functions', - 'special_commands', 'show_commands', 'keywords'] - assert expected_handlers == actual_handlers - - -def test_refresh_called_once(refresher): - """ - - :param refresher: - :return: - """ - callbacks = Mock() - sqlexecute = Mock() - - with patch.object(refresher, '_bg_refresh') as bg_refresh: - actual = refresher.refresh(sqlexecute, callbacks) - time.sleep(1) # Wait for the thread to work. - assert len(actual) == 1 - assert len(actual[0]) == 4 - assert actual[0][3] == 'Auto-completion refresh started in the background.' - bg_refresh.assert_called_with(sqlexecute, callbacks, {}) - - -def test_refresh_called_twice(refresher): - """If refresh is called a second time, it should be restarted. - - :param refresher: - :return: - - """ - callbacks = Mock() - - sqlexecute = Mock() - - def dummy_bg_refresh(*args): - time.sleep(3) # seconds - - refresher._bg_refresh = dummy_bg_refresh - - actual1 = refresher.refresh(sqlexecute, callbacks) - time.sleep(1) # Wait for the thread to work. - assert len(actual1) == 1 - assert len(actual1[0]) == 4 - assert actual1[0][3] == 'Auto-completion refresh started in the background.' - - actual2 = refresher.refresh(sqlexecute, callbacks) - time.sleep(1) # Wait for the thread to work. - assert len(actual2) == 1 - assert len(actual2[0]) == 4 - assert actual2[0][3] == 'Auto-completion refresh restarted.' - - -def test_refresh_with_callbacks(refresher): - """Callbacks must be called. - - :param refresher: - - """ - callbacks = [Mock()] - sqlexecute_class = Mock() - sqlexecute = Mock() - - with patch('mycli.completion_refresher.SQLExecute', sqlexecute_class): - # Set refreshers to 0: we're not testing refresh logic here - refresher.refreshers = {} - refresher.refresh(sqlexecute, callbacks) - time.sleep(1) # Wait for the thread to work. - assert (callbacks[0].call_count == 1) diff --git a/test/test_config.py b/test/test_config.py deleted file mode 100644 index 7f2b2442..00000000 --- a/test/test_config.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for the mycli.config module.""" -from io import BytesIO, StringIO, TextIOWrapper -import os -import struct -import sys -import tempfile -import pytest - -from mycli.config import (get_mylogin_cnf_path, open_mylogin_cnf, - read_and_decrypt_mylogin_cnf, read_config_file, - str_to_bool, strip_matching_quotes) - -LOGIN_PATH_FILE = os.path.abspath(os.path.join(os.path.dirname(__file__), - 'mylogin.cnf')) - - -def open_bmylogin_cnf(name): - """Open contents of *name* in a BytesIO buffer.""" - with open(name, 'rb') as f: - buf = BytesIO() - buf.write(f.read()) - return buf - -def test_read_mylogin_cnf(): - """Tests that a login path file can be read and decrypted.""" - mylogin_cnf = open_mylogin_cnf(LOGIN_PATH_FILE) - - assert isinstance(mylogin_cnf, TextIOWrapper) - - contents = mylogin_cnf.read() - for word in ('[test]', 'user', 'password', 'host', 'port'): - assert word in contents - - -def test_decrypt_blank_mylogin_cnf(): - """Test that a blank login path file is handled correctly.""" - mylogin_cnf = read_and_decrypt_mylogin_cnf(BytesIO()) - assert mylogin_cnf is None - - -def test_corrupted_login_key(): - """Test that a corrupted login path key is handled correctly.""" - buf = open_bmylogin_cnf(LOGIN_PATH_FILE) - - # Skip past the unused bytes - buf.seek(4) - - # Write null bytes over half the login key - buf.write(b'\0\0\0\0\0\0\0\0\0\0') - - buf.seek(0) - mylogin_cnf = read_and_decrypt_mylogin_cnf(buf) - - assert mylogin_cnf is None - - -def test_corrupted_pad(): - """Tests that a login path file with a corrupted pad is partially read.""" - buf = open_bmylogin_cnf(LOGIN_PATH_FILE) - - # Skip past the login key - buf.seek(24) - - # Skip option group - len_buf = buf.read(4) - cipher_len, = struct.unpack(" pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=True - ) - # User didn't set pager, output fits screen -> no pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) - # User manually configured pager, output doesn't fit screen -> pager - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) - # User manually configured pager, output fit screen -> pager - output( - monkeypatch, - terminal_size=(20, 20), - testdata=testdata, - explicit_pager=True, - expect_pager=True - ) - - SPECIAL_COMMANDS['nopager'].handler() - output( - monkeypatch, - terminal_size=(5, 10), - testdata=testdata, - explicit_pager=False, - expect_pager=False - ) - SPECIAL_COMMANDS['pager'].handler('') - - -def test_reserved_space_is_integer(monkeypatch): - """Make sure that reserved space is returned as an integer.""" - def stub_terminal_size(): - return (5, 5) - - with monkeypatch.context() as m: - m.setattr(shutil, 'get_terminal_size', stub_terminal_size) - mycli = MyCli() - assert isinstance(mycli.get_reserved_space(), int) - - -def test_list_dsn(): - runner = CliRunner() - # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as myclirc: - myclirc.write(dedent("""\ - [alias_dsn] - test = mysql://test/test - """)) - myclirc.flush() - args = ['--list-dsn', '--myclirc', myclirc.name] - result = runner.invoke(cli, args=args) - assert result.output == "test\n" - result = runner.invoke(cli, args=args + ['--verbose']) - assert result.output == "test : mysql://test/test\n" - - # delete=False means we should try to clean up - try: - if os.path.exists(myclirc.name): - os.remove(myclirc.name) - except Exception as e: - print(f"An error occurred while attempting to delete the file: {e}") - - - - -def test_prettify_statement(): - statement = 'SELECT 1' - m = MyCli() - pretty_statement = m.handle_prettify_binding(statement) - assert pretty_statement == 'SELECT\n 1;' - - -def test_unprettify_statement(): - statement = 'SELECT\n 1' - m = MyCli() - unpretty_statement = m.handle_unprettify_binding(statement) - assert unpretty_statement == 'SELECT 1;' - - -def test_list_ssh_config(): - runner = CliRunner() - # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ - Host test - Hostname test.example.com - User joe - Port 22222 - IdentityFile ~/.ssh/gateway - """)) - ssh_config.flush() - args = ['--list-ssh-config', '--ssh-config-path', ssh_config.name] - result = runner.invoke(cli, args=args) - assert "test\n" in result.output - result = runner.invoke(cli, args=args + ['--verbose']) - assert "test : test.example.com\n" in result.output - - # delete=False means we should try to clean up - try: - if os.path.exists(ssh_config.name): - os.remove(ssh_config.name) - except Exception as e: - print(f"An error occurred while attempting to delete the file: {e}") - - -def test_dsn(monkeypatch): - # Setup classes to mock mycli.main.MyCli - class Formatter: - format_name = None - - class Logger: - def debug(self, *args, **args_dict): - pass - - def warning(self, *args, **args_dict): - pass - - class MockMyCli: - config = {'alias_dsn': {}} - - def __init__(self, **args): - self.logger = Logger() - self.destructive_warning = False - self.formatter = Formatter() - - def connect(self, **args): - MockMyCli.connect_args = args - - def run_query(self, query, new_line=True): - pass - - import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) - runner = CliRunner() - - # When a user supplies a DSN as database argument to mycli, - # use these values. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:1/dsn_database"] - ) - assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] == "dsn_passwd" and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 1 and \ - MockMyCli.connect_args["database"] == "dsn_database" - - MockMyCli.connect_args = None - - # When a use supplies a DSN as database argument to mycli, - # and used command line arguments, use the command line - # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user:dsn_passwd@dsn_host:2/dsn_database", - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "3", - "--database", "arg_database", - ]) - assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 3 and \ - MockMyCli.connect_args["database"] == "arg_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } - MockMyCli.connect_args = None - - # When a user uses a DSN from the configuration file (alias_dsn), - # use these values. - result = runner.invoke(cli, args=['--dsn', 'test']) - assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "alias_dsn_user" and \ - MockMyCli.connect_args["passwd"] == "alias_dsn_passwd" and \ - MockMyCli.connect_args["host"] == "alias_dsn_host" and \ - MockMyCli.connect_args["port"] == 4 and \ - MockMyCli.connect_args["database"] == "alias_dsn_database" - - MockMyCli.config = { - 'alias_dsn': { - 'test': 'mysql://alias_dsn_user:alias_dsn_passwd@alias_dsn_host:4/alias_dsn_database' - } - } - MockMyCli.connect_args = None - - # When a user uses a DSN from the configuration file (alias_dsn) - # and used command line arguments, use the command line arguments. - result = runner.invoke(cli, args=[ - '--dsn', 'test', '', - "--user", "arg_user", - "--password", "arg_password", - "--host", "arg_host", - "--port", "5", - "--database", "arg_database", - ]) - assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "arg_user" and \ - MockMyCli.connect_args["passwd"] == "arg_password" and \ - MockMyCli.connect_args["host"] == "arg_host" and \ - MockMyCli.connect_args["port"] == 5 and \ - MockMyCli.connect_args["database"] == "arg_database" - - # Use a DSN without password - result = runner.invoke(mycli.main.cli, args=[ - "mysql://dsn_user@dsn_host:6/dsn_database"] - ) - assert result.exit_code == 0, result.output + " " + str(result.exception) - assert \ - MockMyCli.connect_args["user"] == "dsn_user" and \ - MockMyCli.connect_args["passwd"] is None and \ - MockMyCli.connect_args["host"] == "dsn_host" and \ - MockMyCli.connect_args["port"] == 6 and \ - MockMyCli.connect_args["database"] == "dsn_database" - - -def test_ssh_config(monkeypatch): - # Setup classes to mock mycli.main.MyCli - class Formatter: - format_name = None - - class Logger: - def debug(self, *args, **args_dict): - pass - - def warning(self, *args, **args_dict): - pass - - class MockMyCli: - config = {'alias_dsn': {}} - - def __init__(self, **args): - self.logger = Logger() - self.destructive_warning = False - self.formatter = Formatter() - - def connect(self, **args): - MockMyCli.connect_args = args - - def run_query(self, query, new_line=True): - pass - - import mycli.main - monkeypatch.setattr(mycli.main, 'MyCli', MockMyCli) - runner = CliRunner() - - # Setup temporary configuration - # keep Windows from locking the file with delete=False - with NamedTemporaryFile(mode="w",delete=False) as ssh_config: - ssh_config.write(dedent("""\ - Host test - Hostname test.example.com - User joe - Port 22222 - IdentityFile ~/.ssh/gateway - """)) - ssh_config.flush() - - # When a user supplies a ssh config. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "joe" and \ - MockMyCli.connect_args["ssh_host"] == "test.example.com" and \ - MockMyCli.connect_args["ssh_port"] == 22222 and \ - MockMyCli.connect_args["ssh_key_filename"] == os.path.expanduser( - "~") + "/.ssh/gateway" - - # When a user supplies a ssh config host as argument to mycli, - # and used command line arguments, use the command line - # arguments. - result = runner.invoke(mycli.main.cli, args=[ - "--ssh-config-path", - ssh_config.name, - "--ssh-config-host", - "test", - "--ssh-user", "arg_user", - "--ssh-host", "arg_host", - "--ssh-port", "3", - "--ssh-key-filename", "/path/to/key" - ]) - assert result.exit_code == 0, result.output + \ - " " + str(result.exception) - assert \ - MockMyCli.connect_args["ssh_user"] == "arg_user" and \ - MockMyCli.connect_args["ssh_host"] == "arg_host" and \ - MockMyCli.connect_args["ssh_port"] == 3 and \ - MockMyCli.connect_args["ssh_key_filename"] == "/path/to/key" - - # delete=False means we should try to clean up - try: - if os.path.exists(ssh_config.name): - os.remove(ssh_config.name) - except Exception as e: - print(f"An error occurred while attempting to delete the file: {e}") - - -@dbtest -def test_init_command_arg(executor): - init_command = "set sql_select_limit=1000" - sql = 'show variables like "sql_select_limit";' - runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ["--init-command", init_command], input=sql - ) - - expected = "sql_select_limit\t1000\n" - assert result.exit_code == 0 - assert expected in result.output - - -@dbtest -def test_init_command_multiple_arg(executor): - init_command = 'set sql_select_limit=2000; set max_join_size=20000' - sql = ( - 'show variables like "sql_select_limit";\n' - 'show variables like "max_join_size"' - ) - runner = CliRunner() - result = runner.invoke( - cli, args=CLI_ARGS + ['--init-command', init_command], input=sql - ) - - expected_sql_select_limit = 'sql_select_limit\t2000\n' - expected_max_join_size = 'max_join_size\t20000\n' - - assert result.exit_code == 0 - assert expected_sql_select_limit in result.output - assert expected_max_join_size in result.output diff --git a/test/test_naive_completion.py b/test/test_naive_completion.py deleted file mode 100644 index 0bc3bf87..00000000 --- a/test/test_naive_completion.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest -from prompt_toolkit.completion import Completion -from prompt_toolkit.document import Document - - -@pytest.fixture -def completer(): - import mycli.sqlcompleter as sqlcompleter - return sqlcompleter.SQLCompleter(smart_completion=False) - - -@pytest.fixture -def complete_event(): - from unittest.mock import Mock - return Mock() - - -def test_empty_string_completion(completer, complete_event): - text = '' - position = 0 - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list(map(Completion, completer.all_completions)) - - -def test_select_keyword_completion(completer, complete_event): - text = 'SEL' - position = len('SEL') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([Completion(text='SELECT', start_position=-3)]) - - -def test_function_name_completion(completer, complete_event): - text = 'SELECT MA' - position = len('SELECT MA') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert sorted(x.text for x in result) == ["MASTER", "MAX"] - - -def test_column_name_completion(completer, complete_event): - text = 'SELECT FROM users' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list(map(Completion, completer.all_completions)) - - -def test_special_name_completion(completer, complete_event): - text = '\\' - position = len('\\') - result = set(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - # Special commands will NOT be suggested during naive completion mode. - assert result == set() diff --git a/test/test_parseutils.py b/test/test_parseutils.py deleted file mode 100644 index 920a08db..00000000 --- a/test/test_parseutils.py +++ /dev/null @@ -1,190 +0,0 @@ -import pytest -from mycli.packages.parseutils import ( - extract_tables, query_starts_with, queries_start_with, is_destructive, query_has_where_clause, - is_dropping_database) - - -def test_empty_string(): - tables = extract_tables('') - assert tables == [] - - -def test_simple_select_single_table(): - tables = extract_tables('select * from abc') - assert tables == [(None, 'abc', None)] - - -def test_simple_select_single_table_schema_qualified(): - tables = extract_tables('select * from abc.def') - assert tables == [('abc', 'def', None)] - - -def test_simple_select_multiple_tables(): - tables = extract_tables('select * from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] - - -def test_simple_select_multiple_tables_schema_qualified(): - tables = extract_tables('select * from abc.def, ghi.jkl') - assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)] - - -def test_simple_select_with_cols_single_table(): - tables = extract_tables('select a,b from abc') - assert tables == [(None, 'abc', None)] - - -def test_simple_select_with_cols_single_table_schema_qualified(): - tables = extract_tables('select a,b from abc.def') - assert tables == [('abc', 'def', None)] - - -def test_simple_select_with_cols_multiple_tables(): - tables = extract_tables('select a,b from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] - - -def test_simple_select_with_cols_multiple_tables_with_schema(): - tables = extract_tables('select a,b from abc.def, def.ghi') - assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)] - - -def test_select_with_hanging_comma_single_table(): - tables = extract_tables('select a, from abc') - assert tables == [(None, 'abc', None)] - - -def test_select_with_hanging_comma_multiple_tables(): - tables = extract_tables('select a, from abc, def') - assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)] - - -def test_select_with_hanging_period_multiple_tables(): - tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2') - assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')] - - -def test_simple_insert_single_table(): - tables = extract_tables('insert into abc (id, name) values (1, "def")') - - # sqlparse mistakenly assigns an alias to the table - # assert tables == [(None, 'abc', None)] - assert tables == [(None, 'abc', 'abc')] - - -@pytest.mark.xfail -def test_simple_insert_single_table_schema_qualified(): - tables = extract_tables('insert into abc.def (id, name) values (1, "def")') - assert tables == [('abc', 'def', None)] - - -def test_simple_update_table(): - tables = extract_tables('update abc set id = 1') - assert tables == [(None, 'abc', None)] - - -def test_simple_update_table_with_schema(): - tables = extract_tables('update abc.def set id = 1') - assert tables == [('abc', 'def', None)] - - -def test_join_table(): - tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num') - assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')] - - -def test_join_table_schema_qualified(): - tables = extract_tables( - 'SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num') - assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')] - - -def test_join_as_table(): - tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5') - assert tables == [(None, 'my_table', 'm')] - - -def test_query_starts_with(): - query = 'USE test;' - assert query_starts_with(query, ('use', )) is True - - query = 'DROP DATABASE test;' - assert query_starts_with(query, ('use', )) is False - - -def test_query_starts_with_comment(): - query = '# comment\nUSE test;' - assert query_starts_with(query, ('use', )) is True - - -def test_queries_start_with(): - sql = ( - '# comment\n' - 'show databases;' - 'use foo;' - ) - assert queries_start_with(sql, ('show', 'select')) is True - assert queries_start_with(sql, ('use', 'drop')) is True - assert queries_start_with(sql, ('delete', 'update')) is False - - -def test_is_destructive(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'drop database foo;' - ) - assert is_destructive(sql) is True - - -def test_is_destructive_update_with_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1 WHERE id = 1;' - ) - assert is_destructive(sql) is False - - -def test_is_destructive_update_without_where_clause(): - sql = ( - 'use test;\n' - 'show databases;\n' - 'UPDATE test SET x = 1;' - ) - assert is_destructive(sql) is True - - -@pytest.mark.parametrize( - ('sql', 'has_where_clause'), - [ - ('update test set dummy = 1;', False), - ('update test set dummy = 1 where id = 1);', True), - ], -) -def test_query_has_where_clause(sql, has_where_clause): - assert query_has_where_clause(sql) is has_where_clause - - -@pytest.mark.parametrize( - ('sql', 'dbname', 'is_dropping'), - [ - ('select bar from foo', 'foo', False), - ('drop database "foo";', '`foo`', True), - ('drop schema foo', 'foo', True), - ('drop schema foo', 'bar', False), - ('drop database bar', 'foo', False), - ('drop database foo', None, False), - ('drop database foo; create database foo', 'foo', False), - ('drop database foo; create database bar', 'foo', True), - ('select bar from foo; drop database bazz', 'foo', False), - ('select bar from foo; drop database bazz', 'bazz', True), - ('-- dropping database \n ' - 'drop -- really dropping \n ' - 'schema abc -- now it is dropped', - 'abc', - True) - ] -) -def test_is_dropping_database(sql, dbname, is_dropping): - assert is_dropping_database(sql, dbname) == is_dropping diff --git a/test/test_plan.wiki b/test/test_plan.wiki deleted file mode 100644 index 43e90838..00000000 --- a/test/test_plan.wiki +++ /dev/null @@ -1,38 +0,0 @@ -= Gross Checks = - * [ ] Check connecting to a local database. - * [ ] Check connecting to a remote database. - * [ ] Check connecting to a database with a user/password. - * [ ] Check connecting to a non-existent database. - * [ ] Test changing the database. - - == PGExecute == - * [ ] Test successful execution given a cursor. - * [ ] Test unsuccessful execution with a syntax error. - * [ ] Test a series of executions with the same cursor without failure. - * [ ] Test a series of executions with the same cursor with failure. - * [ ] Test passing in a special command. - - == Naive Autocompletion == - * [ ] Input empty string, ask for completions - Everything. - * [ ] Input partial prefix, ask for completions - Stars with prefix. - * [ ] Input fully autocompleted string, ask for completions - Only full match - * [ ] Input non-existent prefix, ask for completions - nothing - * [ ] Input lowercase prefix - case insensitive completions - - == Smart Autocompletion == - * [ ] Input empty string and check if only keywords are returned. - * [ ] Input SELECT prefix and check if only columns and '*' are returned. - * [ ] Input SELECT blah - only keywords are returned. - * [ ] Input SELECT * FROM - Table names only - - == PGSpecial == - * [ ] Test \d - * [ ] Test \d tablename - * [ ] Test \d tablena* - * [ ] Test \d non-existent-tablename - * [ ] Test \d index - * [ ] Test \d sequence - * [ ] Test \d view - - == Exceptionals == - * [ ] Test the 'use' command to change db. diff --git a/test/test_prompt_utils.py b/test/test_prompt_utils.py deleted file mode 100644 index 2373fac8..00000000 --- a/test/test_prompt_utils.py +++ /dev/null @@ -1,11 +0,0 @@ -import click - -from mycli.packages.prompt_utils import confirm_destructive_query - - -def test_confirm_destructive_query_notty(): - stdin = click.get_text_stream('stdin') - assert stdin.isatty() is False - - sql = 'drop database foo;' - assert confirm_destructive_query(sql) is None diff --git a/test/test_smart_completion_public_schema_only.py b/test/test_smart_completion_public_schema_only.py deleted file mode 100644 index b60e67c5..00000000 --- a/test/test_smart_completion_public_schema_only.py +++ /dev/null @@ -1,389 +0,0 @@ -import pytest -from unittest.mock import patch -from prompt_toolkit.completion import Completion -from prompt_toolkit.document import Document -import mycli.packages.special.main as special - -metadata = { - 'users': ['id', 'email', 'first_name', 'last_name'], - 'orders': ['id', 'ordered_date', 'status'], - 'select': ['id', 'insert', 'ABC'], - 'réveillé': ['id', 'insert', 'ABC'] -} - - -@pytest.fixture -def completer(): - - import mycli.sqlcompleter as sqlcompleter - comp = sqlcompleter.SQLCompleter(smart_completion=True) - - tables, columns = [], [] - - for table, cols in metadata.items(): - tables.append((table,)) - columns.extend([(table, col) for col in cols]) - - comp.set_dbname('test') - comp.extend_schemata('test') - comp.extend_relations(tables, kind='tables') - comp.extend_columns(columns, kind='tables') - comp.extend_special_commands(special.COMMANDS) - - return comp - - -@pytest.fixture -def complete_event(): - from unittest.mock import Mock - return Mock() - - -def test_special_name_completion(completer, complete_event): - text = '\\d' - position = len('\\d') - result = completer.get_completions( - Document(text=text, cursor_position=position), - complete_event) - assert result == [Completion(text='\\dt', start_position=-2)] - - -def test_empty_string_completion(completer, complete_event): - text = '' - position = 0 - result = list( - completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert list(map(Completion, completer.keywords + - completer.special_commands)) == result - - -def test_select_keyword_completion(completer, complete_event): - text = 'SEL' - position = len('SEL') - result = completer.get_completions( - Document(text=text, cursor_position=position), - complete_event) - assert list(result) == list([Completion(text='SELECT', start_position=-3)]) - - -def test_table_completion(completer, complete_event): - text = 'SELECT * FROM ' - position = len(text) - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event) - assert list(result) == list([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0), - Completion(text='`select`', start_position=0), - Completion(text='`réveillé`', start_position=0), - ]) - - -def test_function_name_completion(completer, complete_event): - text = 'SELECT MA' - position = len('SELECT MA') - result = completer.get_completions( - Document(text=text, cursor_position=position), complete_event) - assert list(result) == list([Completion(text='MAX', start_position=-2), - Completion(text='MASTER', start_position=-2), - ]) - - -def test_suggested_column_names(completer, complete_event): - """Suggest column and function names when selecting from table. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT from users' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0), - ] + - list(map(Completion, completer.functions)) + - [Completion(text='users', start_position=0)] + - list(map(Completion, completer.keywords))) - - -def test_suggested_column_names_in_function(completer, complete_event): - """Suggest column and function names when selecting multiple columns from - table. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT MAX( from users' - position = len('SELECT MAX(') - result = completer.get_completions( - Document(text=text, cursor_position=position), - complete_event) - assert list(result) == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)]) - - -def test_suggested_column_names_with_table_dot(completer, complete_event): - """Suggest column names on table name and dot. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT users. from users' - position = len('SELECT users.') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)]) - - -def test_suggested_column_names_with_alias(completer, complete_event): - """Suggest column names on table alias and dot. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT u. from users u' - position = len('SELECT u.') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)]) - - -def test_suggested_multiple_column_names(completer, complete_event): - """Suggest column and function names when selecting multiple columns from - table. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT id, from users u' - position = len('SELECT id, ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)] + - list(map(Completion, completer.functions)) + - [Completion(text='u', start_position=0)] + - list(map(Completion, completer.keywords))) - - -def test_suggested_multiple_column_names_with_alias(completer, complete_event): - """Suggest column names on table alias and dot when selecting multiple - columns from table. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT u.id, u. from users u' - position = len('SELECT u.id, u.') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)]) - - -def test_suggested_multiple_column_names_with_dot(completer, complete_event): - """Suggest column names on table names and dot when selecting multiple - columns from table. - - :param completer: - :param complete_event: - :return: - - """ - text = 'SELECT users.id, users. from users u' - position = len('SELECT users.id, users.') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='email', start_position=0), - Completion(text='first_name', start_position=0), - Completion(text='last_name', start_position=0)]) - - -def test_suggested_aliases_after_on(completer, complete_event): - text = 'SELECT u.name, o.id FROM users u JOIN orders o ON ' - position = len('SELECT u.name, o.id FROM users u JOIN orders o ON ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='u', start_position=0), - Completion(text='o', start_position=0), - ]) - - -def test_suggested_aliases_after_on_right_side(completer, complete_event): - text = 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ' - position = len( - 'SELECT u.name, o.id FROM users u JOIN orders o ON o.user_id = ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='u', start_position=0), - Completion(text='o', start_position=0), - ]) - - -def test_suggested_tables_after_on(completer, complete_event): - text = 'SELECT users.name, orders.id FROM users JOIN orders ON ' - position = len('SELECT users.name, orders.id FROM users JOIN orders ON ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0), - ]) - - -def test_suggested_tables_after_on_right_side(completer, complete_event): - text = 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ' - position = len( - 'SELECT users.name, orders.id FROM users JOIN orders ON orders.user_id = ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0), - ]) - - -def test_table_names_after_from(completer, complete_event): - text = 'SELECT * FROM ' - position = len('SELECT * FROM ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='users', start_position=0), - Completion(text='orders', start_position=0), - Completion(text='`select`', start_position=0), - Completion(text='`réveillé`', start_position=0), - ]) - - -def test_auto_escaped_col_names(completer, complete_event): - text = 'SELECT from `select`' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == [ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='`insert`', start_position=0), - Completion(text='`ABC`', start_position=0), - ] + \ - list(map(Completion, completer.functions)) + \ - [Completion(text='select', start_position=0)] + \ - list(map(Completion, completer.keywords)) - - -def test_un_escaped_table_names(completer, complete_event): - text = 'SELECT from réveillé' - position = len('SELECT ') - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - assert result == list([ - Completion(text='*', start_position=0), - Completion(text='id', start_position=0), - Completion(text='`insert`', start_position=0), - Completion(text='`ABC`', start_position=0), - ] + - list(map(Completion, completer.functions)) + - [Completion(text='réveillé', start_position=0)] + - list(map(Completion, completer.keywords))) - - -def dummy_list_path(dir_name): - dirs = { - '/': [ - 'dir1', - 'file1.sql', - 'file2.sql', - ], - '/dir1': [ - 'subdir1', - 'subfile1.sql', - 'subfile2.sql', - ], - '/dir1/subdir1': [ - 'lastfile.sql', - ], - } - return dirs.get(dir_name, []) - - -@patch('mycli.packages.filepaths.list_path', new=dummy_list_path) -@pytest.mark.parametrize('text,expected', [ - # ('source ', [('~', 0), - # ('/', 0), - # ('.', 0), - # ('..', 0)]), - ('source /', [('dir1', 0), - ('file1.sql', 0), - ('file2.sql', 0)]), - ('source /dir1/', [('subdir1', 0), - ('subfile1.sql', 0), - ('subfile2.sql', 0)]), - ('source /dir1/subdir1/', [('lastfile.sql', 0)]), -]) -def test_file_name_completion(completer, complete_event, text, expected): - position = len(text) - result = list(completer.get_completions( - Document(text=text, cursor_position=position), - complete_event)) - expected = list((Completion(txt, pos) for txt, pos in expected)) - assert result == expected diff --git a/test/test_special_iocommands.py b/test/test_special_iocommands.py deleted file mode 100644 index d0ca45ff..00000000 --- a/test/test_special_iocommands.py +++ /dev/null @@ -1,341 +0,0 @@ -import os -import stat -import tempfile -from time import time -from unittest.mock import patch - -import pytest -from pymysql import ProgrammingError - -import mycli.packages.special - -from .utils import dbtest, db_connection, send_ctrl_c - - -def test_set_get_pager(): - mycli.packages.special.set_pager_enabled(True) - assert mycli.packages.special.is_pager_enabled() - mycli.packages.special.set_pager_enabled(False) - assert not mycli.packages.special.is_pager_enabled() - mycli.packages.special.set_pager('less') - assert os.environ['PAGER'] == "less" - mycli.packages.special.set_pager(False) - assert os.environ['PAGER'] == "less" - del os.environ['PAGER'] - mycli.packages.special.set_pager(False) - mycli.packages.special.disable_pager() - assert not mycli.packages.special.is_pager_enabled() - - -def test_set_get_timing(): - mycli.packages.special.set_timing_enabled(True) - assert mycli.packages.special.is_timing_enabled() - mycli.packages.special.set_timing_enabled(False) - assert not mycli.packages.special.is_timing_enabled() - - -def test_set_get_expanded_output(): - mycli.packages.special.set_expanded_output(True) - assert mycli.packages.special.is_expanded_output() - mycli.packages.special.set_expanded_output(False) - assert not mycli.packages.special.is_expanded_output() - - -def test_editor_command(): - assert mycli.packages.special.editor_command(r'hello\e') - assert mycli.packages.special.editor_command(r'\ehello') - assert not mycli.packages.special.editor_command(r'hello') - - assert mycli.packages.special.get_filename(r'\e filename') == "filename" - - os.environ['EDITOR'] = 'true' - os.environ['VISUAL'] = 'true' - # Set the editor to Notepad on Windows - if os.name != 'nt': - mycli.packages.special.open_external_editor(sql=r'select 1') == "select 1" - else: - pytest.skip('Skipping on Windows platform.') - - - -def test_tee_command(): - mycli.packages.special.write_tee(u"hello world") # write without file set - # keep Windows from locking the file with delete=False - with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"tee " + f.name) - mycli.packages.special.write_tee(u"hello world") - if os.name=='nt': - assert f.read() == b"hello world\r\n" - else: - assert f.read() == b"hello world\n" - - mycli.packages.special.execute(None, u"tee -o " + f.name) - mycli.packages.special.write_tee(u"hello world") - f.seek(0) - if os.name=='nt': - assert f.read() == b"hello world\r\n" - else: - assert f.read() == b"hello world\n" - - mycli.packages.special.execute(None, u"notee") - mycli.packages.special.write_tee(u"hello world") - f.seek(0) - if os.name=='nt': - assert f.read() == b"hello world\r\n" - else: - assert f.read() == b"hello world\n" - - # remove temp file - # delete=False means we should try to clean up - try: - if os.path.exists(f.name): - os.remove(f.name) - except Exception as e: - print(f"An error occurred while attempting to delete the file: {e}") - - - -def test_tee_command_error(): - with pytest.raises(TypeError): - mycli.packages.special.execute(None, 'tee') - - with pytest.raises(OSError): - with tempfile.NamedTemporaryFile() as f: - os.chmod(f.name, stat.S_IRUSR | stat.S_IRGRP | stat.S_IROTH) - mycli.packages.special.execute(None, 'tee {}'.format(f.name)) - - -@dbtest - -@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") -def test_favorite_query(): - with db_connection().cursor() as cur: - query = u'select "✔"' - mycli.packages.special.execute(cur, u'\\fs check {0}'.format(query)) - assert next(mycli.packages.special.execute( - cur, u'\\f check'))[0] == "> " + query - - -def test_once_command(): - with pytest.raises(TypeError): - mycli.packages.special.execute(None, u"\\once") - - with pytest.raises(OSError): - mycli.packages.special.execute(None, u"\\once /proc/access-denied") - - mycli.packages.special.write_once(u"hello world") # write without file set - # keep Windows from locking the file with delete=False - with tempfile.NamedTemporaryFile(delete=False) as f: - mycli.packages.special.execute(None, u"\\once " + f.name) - mycli.packages.special.write_once(u"hello world") - if os.name=='nt': - assert f.read() == b"hello world\r\n" - else: - assert f.read() == b"hello world\n" - - mycli.packages.special.execute(None, u"\\once -o " + f.name) - mycli.packages.special.write_once(u"hello world line 1") - mycli.packages.special.write_once(u"hello world line 2") - f.seek(0) - if os.name=='nt': - assert f.read() == b"hello world line 1\r\nhello world line 2\r\n" - else: - assert f.read() == b"hello world line 1\nhello world line 2\n" - # delete=False means we should try to clean up - try: - if os.path.exists(f.name): - os.remove(f.name) - except Exception as e: - print(f"An error occurred while attempting to delete the file: {e}") - - -def test_pipe_once_command(): - with pytest.raises(IOError): - mycli.packages.special.execute(None, u"\\pipe_once") - - with pytest.raises(OSError): - mycli.packages.special.execute( - None, u"\\pipe_once /proc/access-denied") - - if os.name == 'nt': - mycli.packages.special.execute(None, u'\\pipe_once python -c "import sys; print(len(sys.stdin.read().strip()))"') - mycli.packages.special.write_once(u"hello world") - mycli.packages.special.unset_pipe_once_if_written() - else: - mycli.packages.special.execute(None, u"\\pipe_once wc") - mycli.packages.special.write_once(u"hello world") - mycli.packages.special.unset_pipe_once_if_written() - # how to assert on wc output? - - -def test_parseargfile(): - """Test that parseargfile expands the user directory.""" - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'a'} - - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '~\\filename') - else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '~/filename') - - expected = {'file': os.path.join(os.path.expanduser('~'), 'filename'), - 'mode': 'w'} - if os.name=='nt': - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~\\filename') - else: - assert expected == mycli.packages.special.iocommands.parseargfile( - '-o ~/filename') - - -def test_parseargfile_no_file(): - """Test that parseargfile raises a TypeError if there is no filename.""" - with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('') - - with pytest.raises(TypeError): - mycli.packages.special.iocommands.parseargfile('-o ') - - -@dbtest -def test_watch_query_iteration(): - """Test that a single iteration of the result of `watch_query` executes - the desired query and returns the given results.""" - expected_value = "1" - query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) - with db_connection().cursor() as cur: - result = next(mycli.packages.special.iocommands.watch_query( - arg=query, cur=cur - )) - assert result[0] == expected_title - assert result[2][0] == expected_value - - -@dbtest -@pytest.mark.skipif(os.name == "nt", reason="Bug: Win handles this differently. May need to refactor watch_query to work for Win") -def test_watch_query_full(): - """Test that `watch_query`: - - * Returns the expected results. - * Executes the defined times inside the given interval, in this case with - a 0.3 seconds wait, it should execute 4 times inside a 1 seconds - interval. - * Stops at Ctrl-C - - """ - watch_seconds = 0.3 - wait_interval = 1 - expected_value = "1" - query = "SELECT {0!s}".format(expected_value) - expected_title = '> {0!s}'.format(query) - expected_results = 4 - ctrl_c_process = send_ctrl_c(wait_interval) - with db_connection().cursor() as cur: - results = list( - result for result in mycli.packages.special.iocommands.watch_query( - arg='{0!s} {1!s}'.format(watch_seconds, query), cur=cur - ) - ) - ctrl_c_process.join(1) - assert len(results) == expected_results - for result in results: - assert result[0] == expected_title - assert result[2][0] == expected_value - - -@dbtest -@patch('click.clear') -def test_watch_query_clear(clear_mock): - """Test that the screen is cleared with the -c flag of `watch` command - before execute the query.""" - with db_connection().cursor() as cur: - watch_gen = mycli.packages.special.iocommands.watch_query( - arg='0.1 -c select 1;', cur=cur - ) - assert not clear_mock.called - next(watch_gen) - assert clear_mock.called - clear_mock.reset_mock() - next(watch_gen) - assert clear_mock.called - clear_mock.reset_mock() - - -@dbtest -def test_watch_query_bad_arguments(): - """Test different incorrect combinations of arguments for `watch` - command.""" - watch_query = mycli.packages.special.iocommands.watch_query - with db_connection().cursor() as cur: - with pytest.raises(ProgrammingError): - next(watch_query('a select 1;', cur=cur)) - with pytest.raises(ProgrammingError): - next(watch_query('-a select 1;', cur=cur)) - with pytest.raises(ProgrammingError): - next(watch_query('1 -a select 1;', cur=cur)) - with pytest.raises(ProgrammingError): - next(watch_query('-c -a select 1;', cur=cur)) - - -@dbtest -@patch('click.clear') -def test_watch_query_interval_clear(clear_mock): - """Test `watch` command with interval and clear flag.""" - def test_asserts(gen): - clear_mock.reset_mock() - start = time() - next(gen) - assert clear_mock.called - next(gen) - exec_time = time() - start - assert exec_time > seconds and exec_time < (seconds + seconds) - - seconds = 1.0 - watch_query = mycli.packages.special.iocommands.watch_query - with db_connection().cursor() as cur: - test_asserts(watch_query('{0!s} -c select 1;'.format(seconds), - cur=cur)) - test_asserts(watch_query('-c {0!s} select 1;'.format(seconds), - cur=cur)) - - -def test_split_sql_by_delimiter(): - for delimiter_str in (';', '$', '😀'): - mycli.packages.special.set_delimiter(delimiter_str) - sql_input = "select 1{} select \ufffc2".format(delimiter_str) - queries = ( - "select 1", - "select \ufffc2" - ) - for query, parsed_query in zip( - queries, mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) - - -def test_switch_delimiter_within_query(): - mycli.packages.special.set_delimiter(';') - sql_input = "select 1; delimiter $$ select 2 $$ select 3 $$" - queries = ( - "select 1", - "delimiter $$ select 2 $$ select 3 $$", - "select 2", - "select 3" - ) - for query, parsed_query in zip( - queries, - mycli.packages.special.split_queries(sql_input)): - assert(query == parsed_query) - - -def test_set_delimiter(): - - for delim in ('foo', 'bar'): - mycli.packages.special.set_delimiter(delim) - assert mycli.packages.special.get_current_delimiter() == delim - - -def teardown_function(): - mycli.packages.special.set_delimiter(';') diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py deleted file mode 100644 index ca186bcb..00000000 --- a/test/test_sqlexecute.py +++ /dev/null @@ -1,304 +0,0 @@ -import os - -import pytest -import pymysql - -from mycli.sqlexecute import ServerInfo, ServerSpecies -from .utils import run, dbtest, set_expanded_output, is_expanded_output - - -def assert_result_equal(result, title=None, rows=None, headers=None, - status=None, auto_status=True, assert_contains=False): - """Assert that an sqlexecute.run() result matches the expected values.""" - if status is None and auto_status and rows: - status = '{} row{} in set'.format( - len(rows), 's' if len(rows) > 1 else '') - fields = {'title': title, 'rows': rows, 'headers': headers, - 'status': status} - - if assert_contains: - # Do a loose match on the results using the *in* operator. - for key, field in fields.items(): - if field: - assert field in result[0][key] - else: - # Do an exact match on the fields. - assert result == [fields] - - -@dbtest -def test_conn(executor): - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - results = run(executor, '''select * from test''') - - assert_result_equal(results, headers=['a'], rows=[('abc',)]) - - -@dbtest -def test_bools(executor): - run(executor, '''create table test(a boolean)''') - run(executor, '''insert into test values(True)''') - results = run(executor, '''select * from test''') - - assert_result_equal(results, headers=['a'], rows=[(1,)]) - - -@dbtest -def test_binary(executor): - run(executor, '''create table bt(geom linestring NOT NULL)''') - run(executor, "INSERT INTO bt VALUES " - "(ST_GeomFromText('LINESTRING(116.37604 39.73979,116.375 39.73965)'));") - results = run(executor, '''select * from bt''') - - geom = (b'\x00\x00\x00\x00\x01\x02\x00\x00\x00\x02\x00\x00\x009\x7f\x13\n' - b'\x11\x18]@4\xf4Op\xb1\xdeC@\x00\x00\x00\x00\x00\x18]@B>\xe8\xd9' - b'\xac\xdeC@') - - assert_result_equal(results, headers=['geom'], rows=[(geom,)]) - - -@dbtest -def test_table_and_columns_query(executor): - run(executor, "create table a(x text, y text)") - run(executor, "create table b(z text)") - - assert set(executor.tables()) == set([('a',), ('b',)]) - assert set(executor.table_columns()) == set( - [('a', 'x'), ('a', 'y'), ('b', 'z')]) - - -@dbtest -def test_database_list(executor): - databases = executor.databases() - assert 'mycli_test_db' in databases - - -@dbtest -def test_invalid_syntax(executor): - with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, 'invalid syntax!') - assert 'You have an error in your SQL syntax;' in str(excinfo.value) - - -@dbtest -def test_invalid_column_name(executor): - with pytest.raises(pymysql.err.OperationalError) as excinfo: - run(executor, 'select invalid command') - assert "Unknown column 'invalid' in 'field list'" in str(excinfo.value) - - -@dbtest -def test_unicode_support_in_output(executor): - run(executor, "create table unicodechars(t text)") - run(executor, u"insert into unicodechars (t) values ('é')") - - # See issue #24, this raises an exception without proper handling - results = run(executor, u"select * from unicodechars") - assert_result_equal(results, headers=['t'], rows=[(u'é',)]) - - -@dbtest -def test_multiple_queries_same_line(executor): - results = run(executor, "select 'foo'; select 'bar'") - - expected = [{'title': None, 'headers': ['foo'], 'rows': [('foo',)], - 'status': '1 row in set'}, - {'title': None, 'headers': ['bar'], 'rows': [('bar',)], - 'status': '1 row in set'}] - assert expected == results - - -@dbtest -def test_multiple_queries_same_line_syntaxerror(executor): - with pytest.raises(pymysql.ProgrammingError) as excinfo: - run(executor, "select 'foo'; invalid syntax") - assert 'You have an error in your SQL syntax;' in str(excinfo.value) - - -@dbtest -@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") -def test_favorite_query(executor): - set_expanded_output(False) - run(executor, "create table test(a text)") - run(executor, "insert into test values('abc')") - run(executor, "insert into test values('def')") - - results = run(executor, "\\fs test-a select * from test where a like 'a%'") - assert_result_equal(results, status='Saved.') - - results = run(executor, "\\f test-a") - assert_result_equal(results, - title="> select * from test where a like 'a%'", - headers=['a'], rows=[('abc',)], auto_status=False) - - results = run(executor, "\\fd test-a") - assert_result_equal(results, status='test-a: Deleted') - - -@dbtest -@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") -def test_favorite_query_multiple_statement(executor): - set_expanded_output(False) - run(executor, "create table test(a text)") - run(executor, "insert into test values('abc')") - run(executor, "insert into test values('def')") - - results = run(executor, - "\\fs test-ad select * from test where a like 'a%'; " - "select * from test where a like 'd%'") - assert_result_equal(results, status='Saved.') - - results = run(executor, "\\f test-ad") - expected = [{'title': "> select * from test where a like 'a%'", - 'headers': ['a'], 'rows': [('abc',)], 'status': None}, - {'title': "> select * from test where a like 'd%'", - 'headers': ['a'], 'rows': [('def',)], 'status': None}] - assert expected == results - - results = run(executor, "\\fd test-ad") - assert_result_equal(results, status='test-ad: Deleted') - - -@dbtest -@pytest.mark.skipif(os.name == "nt", reason="Bug: fails on Windows, needs fixing, singleton of FQ not working right") -def test_favorite_query_expanded_output(executor): - set_expanded_output(False) - run(executor, '''create table test(a text)''') - run(executor, '''insert into test values('abc')''') - - results = run(executor, "\\fs test-ae select * from test") - assert_result_equal(results, status='Saved.') - - results = run(executor, "\\f test-ae \\G") - assert is_expanded_output() is True - assert_result_equal(results, title='> select * from test', - headers=['a'], rows=[('abc',)], auto_status=False) - - set_expanded_output(False) - - results = run(executor, "\\fd test-ae") - assert_result_equal(results, status='test-ae: Deleted') - - -@dbtest -def test_special_command(executor): - results = run(executor, '\\?') - assert_result_equal(results, rows=('quit', '\\q', 'Quit.'), - headers='Command', assert_contains=True, - auto_status=False) - - -@dbtest -def test_cd_command_without_a_folder_name(executor): - results = run(executor, 'system cd') - assert_result_equal(results, status='No folder name was provided.') - - -@dbtest -def test_system_command_not_found(executor): - results = run(executor, 'system xyz') - if os.name=='nt': - assert_result_equal(results, status='OSError: The system cannot find the file specified', - assert_contains=True) - else: - assert_result_equal(results, status='OSError: No such file or directory', - assert_contains=True) - - -@dbtest -def test_system_command_output(executor): - eol = os.linesep - test_dir = os.path.abspath(os.path.dirname(__file__)) - test_file_path = os.path.join(test_dir, 'test.txt') - results = run(executor, 'system cat {0}'.format(test_file_path)) - assert_result_equal(results, status=f'mycli rocks!{eol}') - - -@dbtest -def test_cd_command_current_dir(executor): - test_path = os.path.abspath(os.path.dirname(__file__)) - run(executor, 'system cd {0}'.format(test_path)) - assert os.getcwd() == test_path - - -@dbtest -def test_unicode_support(executor): - results = run(executor, u"SELECT '日本語' AS japanese;") - assert_result_equal(results, headers=['japanese'], rows=[(u'日本語',)]) - - -@dbtest -def test_timestamp_null(executor): - run(executor, '''create table ts_null(a timestamp null)''') - run(executor, '''insert into ts_null values(null)''') - results = run(executor, '''select * from ts_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) - - -@dbtest -def test_datetime_null(executor): - run(executor, '''create table dt_null(a datetime null)''') - run(executor, '''insert into dt_null values(null)''') - results = run(executor, '''select * from dt_null''') - assert_result_equal(results, headers=['a'], - rows=[(None,)]) - - -@dbtest -def test_date_null(executor): - run(executor, '''create table date_null(a date null)''') - run(executor, '''insert into date_null values(null)''') - results = run(executor, '''select * from date_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) - - -@dbtest -def test_time_null(executor): - run(executor, '''create table time_null(a time null)''') - run(executor, '''insert into time_null values(null)''') - results = run(executor, '''select * from time_null''') - assert_result_equal(results, headers=['a'], rows=[(None,)]) - - -@dbtest -def test_multiple_results(executor): - query = '''CREATE PROCEDURE dmtest() - BEGIN - SELECT 1; - SELECT 2; - END''' - executor.conn.cursor().execute(query) - - results = run(executor, 'call dmtest;') - expected = [ - {'title': None, 'rows': [(1,)], 'headers': ['1'], - 'status': '1 row in set'}, - {'title': None, 'rows': [(2,)], 'headers': ['2'], - 'status': '1 row in set'} - ] - assert results == expected - - -@pytest.mark.parametrize( - 'version_string, species, parsed_version_string, version', - ( - ('5.7.25-TiDB-v6.1.0','TiDB', '6.1.0', 60100), - ('8.0.11-TiDB-v7.2.0-alpha-69-g96e9e68daa', 'TiDB', '7.2.0', 70200), - ('5.7.32-35', 'Percona', '5.7.32', 50732), - ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), - ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), - ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), - ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), - ('unexpected version string', None, '', 0), - ('', None, '', 0), - (None, None, '', 0), - ) -) -def test_version_parsing(version_string, species, parsed_version_string, version): - server_info = ServerInfo.from_version_string(version_string) - assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown - assert server_info.version_str == parsed_version_string - assert server_info.version == version diff --git a/test/test_tabular_output.py b/test/test_tabular_output.py deleted file mode 100644 index bdc1dbf0..00000000 --- a/test/test_tabular_output.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Test the sql output adapter.""" - -from textwrap import dedent - -from mycli.packages.tabular_output import sql_format -from cli_helpers.tabular_output import TabularOutputFormatter - -from .utils import USER, PASSWORD, HOST, PORT, dbtest - -import pytest -from mycli.main import MyCli - -from pymysql.constants import FIELD_TYPE - - -@pytest.fixture -def mycli(): - cli = MyCli() - cli.connect(None, USER, PASSWORD, HOST, PORT, None, init_command=None) - return cli - - -@dbtest -def test_sql_output(mycli): - """Test the sql output adapter.""" - headers = ['letters', 'number', 'optional', 'float', 'binary'] - - class FakeCursor(object): - def __init__(self): - self.data = [ - ('abc', 1, None, 10.0, b'\xAA'), - ('d', 456, '1', 0.5, b'\xAA\xBB') - ] - self.description = [ - (None, FIELD_TYPE.VARCHAR), - (None, FIELD_TYPE.LONG), - (None, FIELD_TYPE.LONG), - (None, FIELD_TYPE.FLOAT), - (None, FIELD_TYPE.BLOB) - ] - - def __iter__(self): - return self - - def __next__(self): - if self.data: - return self.data.pop(0) - else: - raise StopIteration() - - def description(self): - return self.description - - # Test sql-update output format - assert list(mycli.change_table_format("sql-update")) == \ - [(None, None, None, 'Changed table format to sql-update')] - mycli.formatter.query = "" - output = mycli.format_output(None, FakeCursor(), headers) - actual = "\n".join(output) - assert actual == dedent('''\ - UPDATE `DUAL` SET - `number` = 1 - , `optional` = NULL - , `float` = 10.0e0 - , `binary` = X'aa' - WHERE `letters` = 'abc'; - UPDATE `DUAL` SET - `number` = 456 - , `optional` = '1' - , `float` = 0.5e0 - , `binary` = X'aabb' - WHERE `letters` = 'd';''') - # Test sql-update-2 output format - assert list(mycli.change_table_format("sql-update-2")) == \ - [(None, None, None, 'Changed table format to sql-update-2')] - mycli.formatter.query = "" - output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ - UPDATE `DUAL` SET - `optional` = NULL - , `float` = 10.0e0 - , `binary` = X'aa' - WHERE `letters` = 'abc' AND `number` = 1; - UPDATE `DUAL` SET - `optional` = '1' - , `float` = 0.5e0 - , `binary` = X'aabb' - WHERE `letters` = 'd' AND `number` = 456;''') - # Test sql-insert output format (without table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] - mycli.formatter.query = "" - output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ - INSERT INTO `DUAL` (`letters`, `number`, `optional`, `float`, `binary`) VALUES - ('abc', 1, NULL, 10.0e0, X'aa') - , ('d', 456, '1', 0.5e0, X'aabb') - ;''') - # Test sql-insert output format (with table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] - mycli.formatter.query = "SELECT * FROM `table`" - output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ - INSERT INTO table (`letters`, `number`, `optional`, `float`, `binary`) VALUES - ('abc', 1, NULL, 10.0e0, X'aa') - , ('d', 456, '1', 0.5e0, X'aabb') - ;''') - # Test sql-insert output format (with database + table name) - assert list(mycli.change_table_format("sql-insert")) == \ - [(None, None, None, 'Changed table format to sql-insert')] - mycli.formatter.query = "SELECT * FROM `database`.`table`" - output = mycli.format_output(None, FakeCursor(), headers) - assert "\n".join(output) == dedent('''\ - INSERT INTO database.table (`letters`, `number`, `optional`, `float`, `binary`) VALUES - ('abc', 1, NULL, 10.0e0, X'aa') - , ('d', 456, '1', 0.5e0, X'aabb') - ;''') diff --git a/test/utils.py b/test/utils.py index ab122486..5bda3d3d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,28 +1,255 @@ +# type: ignore + +from collections.abc import Iterator +import multiprocessing import os -import time -import signal import platform -import multiprocessing +import signal +import time +from types import SimpleNamespace +from typing import Any, Callable, Literal, cast +from packaging.version import Version +from prompt_toolkit.formatted_text import ( + ANSI, +) +import pygments import pymysql import pytest +from mycli import main +from mycli.constants import ( + DEFAULT_CHARSET, + DEFAULT_HOST, + DEFAULT_PORT, + DEFAULT_USER, + TEST_DATABASE, +) from mycli.main import special +from mycli.packages.sqlresult import SQLResult + +DATABASE = TEST_DATABASE +PASSWORD = os.getenv("PYTEST_PASSWORD") +USER = os.getenv("PYTEST_USER", DEFAULT_USER) +HOST = os.getenv("PYTEST_HOST", DEFAULT_HOST) +PORT = int(os.getenv("PYTEST_PORT", DEFAULT_PORT)) +CHARACTER_SET = os.getenv("PYTEST_CHARSET", DEFAULT_CHARSET) +SSH_USER = os.getenv("PYTEST_SSH_USER", None) +SSH_HOST = os.getenv("PYTEST_SSH_HOST", None) +SSH_PORT = int(os.getenv("PYTEST_SSH_PORT", "22")) +TEMPFILE_PREFIX = 'mycli_test_suite_' + +PYGMENTS_VERSION = Version(pygments.__version__) + + +def pygments_below(version: str) -> bool: + return PYGMENTS_VERSION < Version(version) + + +class DummyLogger: + def __init__(self) -> None: + self.debug_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.error_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + self.warning_calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def debug(self, *args: Any, **kwargs: Any) -> None: + self.debug_calls.append((args, kwargs)) + + def error(self, *args: Any, **kwargs: Any) -> None: + self.error_calls.append((args, kwargs)) + + def warning(self, *args: Any, **kwargs: Any) -> None: + self.warning_calls.append((args, kwargs)) + + +class DummyFormatter: + def __init__(self, format_name: str = 'ascii') -> None: + self.format_name = format_name + self.query = '' + self.supported_formats = ['ascii', 'csv', 'tsv', 'vertical'] + self._output_formats = { + 'ascii': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'csv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'tsv': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + 'vertical': SimpleNamespace(formatter_args={'missing_value': main.DEFAULT_MISSING_VALUE}), + } + self.calls: list[tuple[tuple[Any, ...], dict[str, Any]]] = [] + + def format_output(self, rows: Any, header: Any, format_name: str | None = None, **kwargs: Any) -> list[str] | str: + self.calls.append(((rows, header, format_name), kwargs)) + if format_name == 'vertical': + return ['vertical output'] + return ['plain output'] + + +class ReusableLock: + def __init__(self, on_enter: Callable[[], Any] | None = None) -> None: + self.on_enter = on_enter -PASSWORD = os.getenv('PYTEST_PASSWORD') -USER = os.getenv('PYTEST_USER', 'root') -HOST = os.getenv('PYTEST_HOST', 'localhost') -PORT = int(os.getenv('PYTEST_PORT', 3306)) -CHARSET = os.getenv('PYTEST_CHARSET', 'utf8') -SSH_USER = os.getenv('PYTEST_SSH_USER', None) -SSH_HOST = os.getenv('PYTEST_SSH_HOST', None) -SSH_PORT = os.getenv('PYTEST_SSH_PORT', 22) + def __enter__(self) -> 'ReusableLock': + if self.on_enter is not None: + self.on_enter() + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]: + return False + + +class FakeCursorBase: + def __init__( + self, + rows: list[tuple[Any, ...]] | None = None, + rowcount: int = 0, + description: list[tuple[Any, ...]] | None = None, + warning_count: int = 0, + ) -> None: + self._rows = list(rows or []) + self.rowcount = rowcount + self.description = description or [] + self.warning_count = warning_count + + def __iter__(self) -> Iterator[tuple[Any, ...]]: + return iter(self._rows) + + +class RecordingSQLExecute: + calls: list[dict[str, Any]] = [] + side_effects: list[Any] = [] + + def __init__(self, **kwargs: Any) -> None: + type(self).calls.append(dict(kwargs)) + if type(self).side_effects: + effect = type(self).side_effects.pop(0) + if isinstance(effect, BaseException): + raise effect + if callable(effect): + effect(kwargs) + self.kwargs = kwargs + self.dbname = kwargs.get('database') + self.user = kwargs.get('user') + self.conn = kwargs.get('conn') + self.sandbox_mode = False + + +def make_bare_mycli() -> Any: + cli = object.__new__(main.MyCli) + cli.logger = cast(Any, DummyLogger()) + cli.main_formatter = DummyFormatter() + cli.redirect_formatter = DummyFormatter() + cli.helpers_style = 'helpers-style' + cli.helpers_warnings_style = 'helpers-warnings-style' + cli.ptoolkit_style = cast(Any, 'pt-style') + cli.syntax_style = 'native' + cli.cli_style = {} + cli.null_string = '' + cli.numeric_alignment = 'right' + cli.binary_display = None + cli.show_warnings = False + cli.query_history = [] + cli.toolbar_error_message = None + cli.prompt_session = None + cli.last_prompt_message = ANSI('') + cli.last_custom_toolbar_message = ANSI('') + cli.prompt_lines = 0 + cli.prompt_format = main.MyCli.default_prompt + cli.multiline_continuation_char = '>' + cli.toolbar_format = 'default' + cli.destructive_warning = False + cli.destructive_keywords = ['drop'] + cli.keepalive_ticks = None + cli._keepalive_counter = 0 + cli.verbosity = -1 + cli.smart_completion = False + cli.key_bindings = 'emacs' + cli.auto_vertical_output = False + cli.wider_completion_menu = False + cli.explicit_pager = False + cli._completer_lock = cast(Any, ReusableLock()) + cli.prefetch_schemas_mode = 'never' + cli.prefetch_schemas_list = [] + cli.schema_prefetcher = cast( + Any, + SimpleNamespace( + stop=lambda: None, + clear_loaded=lambda: None, + start_configured=lambda: None, + is_prefetching=lambda: False, + prefetch_schema_now=lambda schema: None, + ), + ) + cli.terminal_tab_title_format = '' + cli.terminal_window_title_format = '' + cli.multiplex_window_title_format = '' + cli.multiplex_pane_title_format = '' + cli.dsn_alias = None + cli.login_path = None + cli.login_path_as_host = False + cli.post_redirect_command = None + cli.logfile = None + cli.emacs_ttimeoutlen = 1.0 + cli.vi_ttimeoutlen = 1.0 + cli.beep_after_seconds = 0.0 + cli.config = {'main': {'history_file': '~/.mycli-history-testing'}} + cli.output = lambda *args, **kwargs: None # type: ignore[assignment] + cli.echo = lambda *args, **kwargs: None # type: ignore[assignment] + cli.log_query = lambda *args, **kwargs: None # type: ignore[assignment] + cli.log_output = lambda *args, **kwargs: None # type: ignore[assignment] + cli.configure_pager = lambda: None # type: ignore[assignment] + cli.refresh_completions = lambda reset=False: [SQLResult(status='refresh')] # type: ignore[assignment] + cli.reconnect = lambda database='': False # type: ignore[assignment] + return cli + + +def make_dummy_mycli_class( + *, + config: dict[str, Any] | None = None, + my_cnf: dict[str, Any] | None = None, + config_without_package_defaults: dict[str, Any] | None = None, +) -> Any: + class DummyMyCli: + last_instance: Any = None + + def __init__(self, **kwargs: Any) -> None: + type(self).last_instance = self + self.init_kwargs = dict(kwargs) + self.config = config or {'main': {}, 'alias_dsn': {}} + self.my_cnf = my_cnf or {'client': {}, 'mysqld': {}} + self.config_without_package_defaults = config_without_package_defaults or {} + self.default_keepalive_ticks = 5 + self.ssl_mode = None + self.logger = DummyLogger() + self.main_formatter = SimpleNamespace(format_name=None) + self.destructive_warning = False + self.destructive_keywords = ['drop'] + self.dsn_alias = None + self.connect_calls: list[dict[str, Any]] = [] + self.run_query_calls: list[tuple[str, Any, bool]] = [] + self.run_cli_called = False + self.close_called = False + self.verbosity = 0 + + def connect(self, **kwargs: Any) -> None: + self.connect_calls.append(dict(kwargs)) + + def run_query(self, query: str, checkpoint: Any = None, new_line: bool = True) -> None: + self.run_query_calls.append((query, checkpoint, new_line)) + + def run_cli(self) -> None: + self.run_cli_called = True + + def close(self) -> None: + self.close_called = True + + return DummyMyCli + + +def call_click_entrypoint_direct(cli_args: main.CliArgs) -> None: + assert main.click_entrypoint.callback is not None + cast(Any, main.click_entrypoint.callback).__wrapped__(cli_args) def db_connection(dbname=None): - conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, - password=PASSWORD, charset=CHARSET, - local_infile=False) + conn = pymysql.connect(user=USER, host=HOST, port=PORT, database=dbname, password=PASSWORD, charset=CHARACTER_SET, local_infile=False) conn.autocommit = True return conn @@ -30,33 +257,37 @@ def db_connection(dbname=None): try: db_connection() CAN_CONNECT_TO_DB = True -except: +except Exception: CAN_CONNECT_TO_DB = False -dbtest = pytest.mark.skipif( - not CAN_CONNECT_TO_DB, - reason="Need a mysql instance at localhost accessible by user 'root'") +dbtest = pytest.mark.skipif(not CAN_CONNECT_TO_DB, reason=f"Need a mysql instance at {DEFAULT_HOST} accessible by user '{DEFAULT_USER}'") def create_db(dbname): with db_connection().cursor() as cur: try: - cur.execute('''DROP DATABASE IF EXISTS mycli_test_db''') - cur.execute('''CREATE DATABASE mycli_test_db''') - except: + cur.execute(f"DROP DATABASE IF EXISTS {TEST_DATABASE}") + cur.execute(f"CREATE DATABASE {TEST_DATABASE}") + except Exception: pass def run(executor, sql, rows_as_list=True): """Return string output for the sql to be run.""" - result = [] + results = [] - for title, rows, headers, status in executor.run(sql): - rows = list(rows) if (rows_as_list and rows) else rows - result.append({'title': title, 'rows': rows, 'headers': headers, - 'status': status}) + for result in executor.run(sql): + rows = list(result.rows) if (rows_as_list and result.rows) else result.rows + results.append({ + "preamble": result.preamble, + "header": result.header, + "rows": rows, + "postamble": result.postamble, + "status": result.status, + "status_plain": result.status_plain, + }) - return result + return results def set_expanded_output(is_expanded): @@ -87,8 +318,6 @@ def send_ctrl_c(wait_seconds): Returns the `multiprocessing.Process` created. """ - ctrl_c_process = multiprocessing.Process( - target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds) - ) + ctrl_c_process = multiprocessing.Process(target=send_ctrl_c_to_pid, args=(os.getpid(), wait_seconds)) ctrl_c_process.start() return ctrl_c_process diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 612e8b7f..00000000 --- a/tox.ini +++ /dev/null @@ -1,15 +0,0 @@ -[tox] -envlist = py36, py37, py38 - -[testenv] -deps = pytest - mock - pexpect - behave - coverage -commands = python setup.py test -passenv = PYTEST_HOST - PYTEST_USER - PYTEST_PASSWORD - PYTEST_PORT - PYTEST_CHARSET