diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 5b5616f1..2fcdc68d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,6 +5,8 @@ updates: directory: "/" schedule: interval: "weekly" + cooldown: + default-days: 7 # optional groups: actions: patterns: diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 2ad98586..ef430d9c 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -7,7 +7,6 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: dask - package-version: '>= 2024.9.0' module-name: dask.array extra-requires: numpy # Dask is substantially slower than other libraries on unit tests. @@ -15,4 +14,7 @@ # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=5 + pytest-extra-args: --max-examples=200 -n 4 + python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml deleted file mode 100644 index 2d81c3cd..00000000 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Array API Tests (NumPy 1.21) - -on: [push, pull_request] - -jobs: - array-api-tests-numpy-1-21: - uses: ./.github/workflows/array-api-tests.yml - with: - package-name: numpy - package-version: '== 1.21.*' - xfails-file-extra: '-1-21' diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml new file mode 100644 index 00000000..83d4cf1d --- /dev/null +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -0,0 +1,15 @@ +name: Array API Tests (NumPy 1.22) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy-1-22: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy + package-version: '== 1.22.*' + xfails-file-extra: '-1-22' + python-versions: '[''3.10'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 660935f0..13124644 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -9,3 +9,7 @@ jobs: package-name: numpy package-version: '== 1.26.*' xfails-file-extra: '-1-26' + python-versions: '[''3.10'', ''3.12'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index eef4269d..7a521360 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,3 +9,7 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' + python-versions: '[''3.11'', ''3.13'', ''3.14'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 36984345..03e0e11e 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -1,4 
+1,4 @@ -name: Array API Tests (NumPy Latest) +name: Array API Tests (NumPy latest) on: [push, pull_request] @@ -7,3 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy + python-versions: '[''3.10'', ''3.13'', ''3.14'']' + pytest-extra-args: -n 4 + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..d5cdfa72 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -1,4 +1,4 @@ -name: Array API Tests (PyTorch Latest) +name: Array API Tests (PyTorch CPU) on: [push, pull_request] @@ -7,5 +7,9 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch + extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + ARRAY_API_TESTS_XFAIL_MARK=skip + python-versions: '[''3.10'', ''3.13'', ''3.14'']' + pytest-extra-args: -n 4 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6ace193a..67ae76f5 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -16,6 +16,10 @@ on: required: false type: string default: '>= 0' + python-versions: + required: true + type: string + description: JSON array of Python versions to test against. pytest-extra-args: required: false type: string @@ -30,53 +34,57 @@ on: extra-env-vars: required: false type: string - description: "Multiline string of environment variables to set for the test run." + description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - # Min version of dask we need dropped support for Python 3.9 - # There is no numpy git tip for Python 3.9 or 3.10 - python-version: ${{ (inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']')) || (inputs.package-name == 'numpy' && inputs.xfails-file-extra == '-dev' && fromJson('[''3.11'', ''3.12'']')) || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }} + python-version: ${{ fromJson(inputs.python-versions) }} steps: - name: Checkout array-api-compat - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: array-api-compat + - name: Checkout array-api-tests - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables # Set additional environment variables if provided if: inputs.extra-env-vars run: | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV + - name: Install dependencies - # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way - # to put this in the numpy 1.21 config file. - if: "! 
((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install pytest-xdist + + - name: Dump pip environment + run: pip freeze + - name: Run the array API testsuite (${{ inputs.package-name }}) - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} - ARRAY_API_TESTS_VERSION: 2024.12 + ARRAY_API_TESTS_VERSION: 2025.12 # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak diff --git a/.github/workflows/dependabot-auto-merge.yml b/.github/workflows/dependabot-auto-merge.yml index bd29e25b..48dfc680 100644 --- a/.github/workflows/dependabot-auto-merge.yml +++ b/.github/workflows/dependabot-auto-merge.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Dependabot metadata id: metadata - uses: dependabot/fetch-metadata@v2 + uses: dependabot/fetch-metadata@v3 with: github-token: "${{ secrets.GITHUB_TOKEN }}" - name: Enable auto-merge for Dependabot PRs diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 04c3aa66..b0c39822 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -6,17 +6,17 @@ jobs: docs-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 - name: Install Dependencies run: | - python -m pip install -r docs/requirements.txt + python -m pip install .[docs] - name: Build Docs run: | cd docs make html - name: Upload Artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: docs-build path: docs/_build/html diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fc612588..f96b39ec 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -11,9 +11,9 @@ jobs: environment: name: docs-deploy steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v9 + uses: dawidd6/action-download-artifact@v20 with: workflow: docs-build.yml name: docs-build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 7733059d..826c5239 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -30,24 +30,25 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' - name: Install python-build and twine run: | - python -m pip install --upgrade pip setuptools + python -m pip install --upgrade pip "setuptools<=67" python -m pip install build twine python -m pip list - name: Build a wheel and a sdist run: | - PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + #PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + python -m build . 
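+          # (presumably disabled because newer setuptools emits DeprecationWarnings that would fail the build when promoted to errors; see the "setuptools<=67" pin above)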
- name: Verify the distribution run: twine check --strict dist/* @@ -59,7 +60,7 @@ jobs: run: python -m zipfile --list dist/array_api_compat-*.whl - name: Upload distribution artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: dist-artifact path: dist @@ -80,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v8 with: name: dist-artifact path: dist @@ -94,19 +95,19 @@ jobs: # if: >- # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.12.4 + # uses: pypa/gh-action-pypi-publish@v1.14.0 # with: # repository-url: https://test.pypi.org/legacy/ # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.12.4 + uses: pypa/gh-action-pypi-publish@v1.14.0 with: print-hash: true - name: Create GitHub Release from a Tag - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 if: startsWith(github.ref, 'refs/tags/') with: files: dist/* diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index a9f0fd4b..6e838902 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -5,9 +5,9 @@ jobs: runs-on: ubuntu-latest continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcd43367..c27283ed 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,43 +4,65 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.21', '1.26', '2.0', 'dev'] - exclude: - - python-version: '3.11' - numpy-version: '1.21' - - python-version: '3.12' - numpy-version: '1.21' - fail-fast: true + include: + - numpy-version: '1.22' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.12' + - numpy-version: 'latest' + python-version: '3.10' + - numpy-version: 'latest' + python-version: '3.13' + - numpy-version: 'latest' + python-version: '3.14' + - numpy-version: 'dev' + python-version: '3.11' + - numpy-version: 'dev' + python-version: '3.13' + - numpy-version: 'dev' + python-version: '3.14' + steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install Dependencies run: | python -m pip install --upgrade pip + python -m pip install pytest + + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.numpy-version }}" == "dev" ]; then - PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' - elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then - PIP_EXTRA='numpy==1.21.*' + python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -m pip install dask[array] jax[cpu] + 
if ["${{ matrix.python-version }}" != "3.14]; then + python -m pip install sparse ndonnx + fi + elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then + python -m pip install 'numpy==1.22.*' + elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then + python -m pip install 'numpy==1.26.*' else - PIP_EXTRA='numpy==1.26.*' + python -m pip install numpy + python -m pip install dask[array] jax[cpu] + if ["${{ matrix.python-version }}" != "3.14]; then + python -m pip install sparse ndonnx + fi fi - if [ "${{ matrix.python-version }}" == "3.9" ]; then - sed -i '/^ndonnx/d' requirements-dev.txt - fi + - name: Dump pip environment + run: pip freeze - python -m pip install -r requirements-dev.txt $PIP_EXTRA + - name: Test it installs + run: python -m pip install . - name: Run Tests - run: | - if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") - fi - pytest -v "${PYTEST_EXTRA[@]}" - - # Make sure it installs - python -m pip install . + run: pytest -v diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 60b37e97..454c7f8b 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.dev0' +__version__ = '1.15.0.dev0' -from .common import * # noqa: F401, F403 +from .common import * # noqa: F403 diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..baa39ded 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,10 +2,17 @@ Internal helpers """ +import importlib +from collections.abc import Callable from functools import wraps from inspect import signature +from types import ModuleType +from typing import TypeVar -def get_xp(xp): +_T = TypeVar("_T") + + +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +29,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +47,31 @@ def wrapped_f(*args, **kwargs): specification for more details. """ - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner + + +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. 
+ objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + + +__all__ = ["get_xp", "clone_module"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 98b8e425..3587ef16 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,142 +4,172 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, cast -from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace +from ._typing import Array, Device, DType, Namespace + +if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. -# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + shape: int | tuple[int, ...], + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] 
= None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, - xp, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + xp: Namespace, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], - xp, + shape: int | tuple[int, ...], + xp: Namespace, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -147,35 +177,37 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. 
Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: + +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,20 +227,16 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +251,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -233,51 +261,58 @@ def unique_values(x: ndarray, /, xp) -> ndarray: **kwargs, ) + # These functions have different keyword argument names + def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, -) -> ndarray: + **kwargs: object, +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs: object, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. 
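+    # An omitted axis is only well-defined for ndim <= 1 (where it becomes 0); + # higher-dimensional input requires an explicit axis.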
if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -287,27 +322,34 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs: object, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) @@ -317,25 +359,32 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + out: Array | None = None, +) -> Array: + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float) or a is None + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -360,44 +409,51 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). 
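+    # (e.g. for an int8 array, min=-1000 can never clip anything, and casting it + # to int8 below would overflow; dropping it sidesteps both issues)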
- if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: - min = None - if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: - max = None + if wrapped_xp.isdtype(x.dtype, "integral"): + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: + max = None + dev = _get_device(x) if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), - copy=True, device=device(x)) + out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright + out[()] = x + if min is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): - # Avoid loss of precision due to torch defaulting to float32 - min = wrapped_xp.asarray(min, dtype=xp.float64) - a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) + a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev) + a = xp.broadcast_to(a, result_shape) ia = (out < a) | xp.isnan(a) - # torch requires an explicit cast here - out[ia] = wrapped_xp.astype(a[ia], out.dtype) + out[ia] = a[ia] + if max is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): - max = wrapped_xp.asarray(max, dtype=xp.float64) - b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) + b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev) + b = xp.broadcast_to(b, result_shape) ib = (out > b) | xp.isnan(b) - out[ib] = wrapped_xp.astype(b[ib], out.dtype) + out[ib] = b[ib] + # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: tuple[int, ...], + xp: Namespace, + *, + copy: bool | None = None, + **kwargs: object, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -406,17 +462,24 @@ def reshape(x: ndarray, return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -433,69 +496,66 @@ def argsort( res = max_i - res return res + def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. 
if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: + +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: + +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -507,11 +567,16 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. 
@@ -524,21 +589,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -549,32 +617,86 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) + # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: - if isdtype(x.dtype, 'complex floating', xp=xp): - out = (x/xp.abs(x, **kwargs))[...] + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0+0j] = 0+0j + out[x == 0j] = 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign'] + +def finfo(type_: DType | Array, /, xp: Namespace) -> Any: + # It is surprisingly difficult to recognize a dtype apart from an array. + # np.int64 is not the same as np.asarray(1).dtype! 
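+    # EAFP: first try the argument as a dtype; if the backend rejects it, + # assume it is an array and retry with its .dtype.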
+ try: + return xp.finfo(type_) + except (ValueError, TypeError): + return xp.finfo(type_.dtype) + + +def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: + try: + return xp.iinfo(type_) + except (ValueError, TypeError): + return xp.iinfo(type_.dtype) + + +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,150 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Literal, TypeAlias -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Array, Device, DType, Namespace + +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if 
x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: + norm: _Norm = "backward", +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +153,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +169,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +182,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ @@ -203,3 +208,6 @@ def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> n "fftshift", "ifftshift", ] + +def __dir__() -> list[str]: + return __all__ diff --git 
a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..d5342658 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,35 +5,94 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -from typing import TYPE_CHECKING +import enum +import inspect +import math +import sys +import warnings +from collections.abc import Collection, Hashable +from functools import lru_cache +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace + import cupy as cp + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse + import torch -import sys -import math -import inspect -import warnings + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs + + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] -def _is_jax_zero_gradient_array(x: object) -> bool: + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | cp.ndarray + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + ) + +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12", "2024.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2025.12"}) + + +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False - import numpy as np - import jax + if "jax" not in sys.modules: + return False - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 -def is_numpy_array(x: object) -> bool: +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -54,15 +113,12 @@ def is_numpy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? 
- return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -86,17 +142,11 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: - return False - - import cupy as cp - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") -def is_torch_array(x: object) -> bool: +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -114,17 +164,11 @@ def is_torch_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") -def is_ndonnx_array(x: object) -> bool: +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -143,16 +187,11 @@ def is_ndonnx_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") -def is_dask_array(x: object) -> bool: +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -171,16 +210,11 @@ def is_dask_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: - return False - - import dask.array + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") - return isinstance(x, dask.array.Array) - -def is_jax_array(x: object) -> bool: +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -200,16 +234,21 @@ def is_jax_array(x: object) -> bool: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) - - -def is_pydata_sparse_array(x) -> bool: + cls = cast(Hashable, type(x)) + # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on, + # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent + # non-array values and a fully correct implementation would need to use isinstance checks. Since + # we use hash-based caching with type names as keys, we cannot use instance checks without + # losing performance here. For more information, see + # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue. + return ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") + or _is_jax_zero_gradient_array(x) + ) + + +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. 
@@ -229,17 +268,12 @@ def is_pydata_sparse_array(x) -> bool: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: - return False - - import sparse - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") -def is_array_api_obj(x: object) -> bool: +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. @@ -254,21 +288,41 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + try: + # TODO: drop this check after np.matrix is gone + if _issubclass_fast(type(x), "numpy", "matrix"): + return False + except Exception: + pass + return ( + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations + ) def _compat_module_name() -> str: - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") -def is_numpy_namespace(xp) -> bool: +@lru_cache(100) +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -286,10 +340,11 @@ def is_numpy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} -def is_cupy_namespace(xp) -> bool: +@lru_cache(100) +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -307,10 +362,11 @@ def is_cupy_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} -def is_torch_namespace(xp) -> bool: +@lru_cache(100) +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -328,10 +384,10 @@ def is_torch_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. 
@@ -347,10 +403,11 @@ def is_ndonnx_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" -def is_dask_namespace(xp) -> bool: +@lru_cache(100) +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -368,10 +425,10 @@ def is_dask_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -390,10 +447,10 @@ def is_jax_namespace(xp) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -409,10 +466,10 @@ def is_pydata_sparse_namespace(xp) -> bool: is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -428,18 +485,105 @@ def is_array_api_strict_namespace(xp) -> bool: is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version: str) -> None: - if api_version in ['2021.12', '2022.12', '2023.12']: - warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") - elif api_version is not None and api_version not in ['2021.12', '2022.12', - '2023.12', '2024.12']: - raise ValueError("Only the 2024.12 version of the array API specification is currently supported") +def _check_api_version(api_version: str | None) -> None: + if api_version in _API_VERSIONS_OLD: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2025.12" + ) + elif api_version is not None and api_version not in _API_VERSIONS: + raise ValueError( + "Only the 2025.12 version of the array API specification is currently supported" + ) -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. 
import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + +def array_namespace( + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. @@ -451,11 +595,11 @@ def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: api_version: str The newest version of the spec that you need support for (currently - the compat library wrapped APIs support v2024.12). + the compat library wrapped APIs support v2025.12). use_compat: bool or None If None (the default), the native namespace will be returned if it is - already array API compatible, otherwise a compat wrapper is used. If + already array API compatible; otherwise, a compat wrapper is used. If True, the compat library wrapped library will be returned. If False, the native library namespace is returned. @@ -503,117 +647,97 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - - namespaces = set() + namespaces: list[Namespace] = [] for x in xs: - if is_numpy_array(x): - from .. import numpy as numpy_namespace - import numpy as np - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. 
import cupy as cupy_namespace - namespaces.add(cupy_namespace) - else: - import cupy as cp - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - namespaces.add(torch_namespace) - else: - import torch - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - namespaces.add(dask_namespace) - else: - import dask.array as da - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. - namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): - if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: continue - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") - if not namespaces: - raise TypeError("Unrecognized array input") + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + xp = get_ns(api_version=api_version) + + namespaces.append(xp) + + # Use a list of modules to avoid a graph break under torch.compile: + # torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable + # Explanation: Dynamo does not know whether the underlying python object for + # PythonModuleVariable( None: # pyright: ignore[reportUnusedFunction] + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. + However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get("numpy"): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get("dask.array"): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. 
# (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -659,7 +783,7 @@ def device(x: Array, /) -> Device: # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): @@ -668,66 +792,66 @@ def device(x: Array, /) -> Device: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # type: ignore # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): +def _cupy_to_device( + x: cp.ndarray, + device: Device, + /, + stream: int | Any | None = None, +) -> cp.ndarray: import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None - if stream is not None: - prev_stream = stream_module.get_current_stream() - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): - pass - else: - raise ValueError('the input stream is not recognized') - stream.use() - try: - runtime.setDevice(device.id) - arr = x.copy() - finally: - runtime.setDevice(prev_device) - if stream is not None: - prev_stream.use() - return arr - -def _torch_to_device(x, device, /, stream=None): + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device type {device!r}") + + if stream is None: + with device: + return cp.asarray(x) + + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = 
cp.cuda.ExternalStream(stream) + elif not isinstance(stream, cp.cuda.Stream): + raise TypeError(f"Unsupported stream type {stream!r}") + + with device, stream: + return cp.asarray(x) + + +def _torch_to_device( + x: torch.Tensor, + device: torch.device | str | int, + /, + stream: int | Any | None = None, +) -> torch.Tensor: if stream is not None: raise NotImplementedError return x.to(device) -def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: + +def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -747,7 +871,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] a ``device`` object (see the `Device Support `__ section of the array API specification). - stream: Optional[Union[int, Any]] + stream: int | Any | None stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using @@ -779,7 +903,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): @@ -791,13 +915,14 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... - import jax.experimental.array_api # noqa: F401 + import jax.experimental.array_api # noqa: F401 # pyright: ignore + # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x @@ -806,10 +931,14 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # Perform trivial check to return the same array if # device is same instead of err-ing. return x - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) # pyright: ignore -def size(x: Array) -> int | None: +@overload +def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +@overload +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ Return the total number of elements of x. 
@@ -824,12 +953,26 @@ def size(x: Array) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out -def is_writeable_array(x: object) -> bool: +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. @@ -839,14 +982,36 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): - return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None -def is_lazy_array(x: object) -> bool: +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. @@ -859,14 +1024,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -876,10 +1033,14 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. 
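The laziness and writeability probes above are pure duck typing, so they are easy to exercise. A minimal usage sketch, assuming `array-api-compat` and NumPy are installed (the Dask result is shown only as a comment, since it needs an optional dependency):

```python
import numpy as np
from array_api_compat import is_lazy_array, is_writeable_array

x = np.arange(3)
print(is_lazy_array(x))       # False: NumPy arrays are eager
print(is_writeable_array(x))  # True while x.flags.writeable is set

x.flags.writeable = False
print(is_writeable_array(x))  # False: ndarray writeability is checked per instance

# A dask.array.Array or ndonnx.Array would instead report is_lazy_array(...) == True.
```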
+ cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. Note that this test may have dire consequences @@ -887,7 +1048,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array - s = size(x) + s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) if s is None: return True xp = array_namespace(x) @@ -899,7 +1060,7 @@ def is_lazy_array(x: object) -> bool: try: bool(x) return False - # The Array API standard dictactes that __bool__ should raise TypeError if the + # The Array API standard dictates that __bool__ should raise TypeError if the # output cannot be defined. # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does. except Exception: @@ -932,4 +1093,5 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..14b560d1 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,85 +1,118 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, cast import numpy as np + if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot +from ._typing import Array, DType, JustFloat, JustInt, Namespace + # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axis: int = -1, + **kwargs: object, +) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array + +class EigResult(NamedTuple): + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. 
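The `*Result` namedtuples above change only the return container, not the math. A sketch with plain NumPy showing why both unpacking styles keep working, using `QRResult` as defined in this file:

```python
from typing import NamedTuple

import numpy as np

class QRResult(NamedTuple):
    Q: np.ndarray
    R: np.ndarray

res = QRResult(*np.linalg.qr(np.eye(3)))
Q, R = res                         # positional unpacking, as with a plain tuple
assert res.Q is Q and res.R is R   # named access, as the array API spec requires
```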
-def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, - /, - xp, - *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +121,14 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +137,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). 
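That equivalence is easy to verify with plain NumPy, which is all the wrapper below relies on:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))

s_only = np.linalg.svd(x, compute_uv=False)  # what svdvals(x, np) computes
_, s_full, _ = np.linalg.svd(x)              # full decomposition
assert np.allclose(s_only, s_full)           # same singular values either way
```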
-def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ord: JustInt | JustFloat = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -117,7 +172,10 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -133,8 +191,14 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) - for i in _axis: + axes = cast( + "tuple[int, ...]", + normalize_axis_tuple( # pyright: ignore[reportCallIssue] + range(x.ndim) if axis is None else axis, + x.ndim, + ), + ) + for i in axes: shape[i] = 1 res = xp.reshape(res, tuple(shape)) @@ -143,14 +207,28 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, + /, + xp: Namespace, + *, + offset: int = 0, + dtype: DType | None = None, + **kwargs: object, +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..11b00bd1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,189 @@ from __future__ import annotations -__all__ = [ - "NestedSequence", - "SupportsBufferProtocol", -] - -from types import ModuleType +from collections.abc import Mapping +from types import ModuleType as Namespace from typing import ( - Any, - TypeVar, + TYPE_CHECKING, + Literal, Protocol, + TypeAlias, + TypedDict, + 
TypeVar,
+    final,
 )

+if TYPE_CHECKING:
+    from _typeshed import Incomplete
+
+    SupportsBufferProtocol: TypeAlias = Incomplete
+    Array: TypeAlias = Incomplete
+    Device: TypeAlias = Incomplete
+    DType: TypeAlias = Incomplete
+else:
+    SupportsBufferProtocol = object
+    Array = object
+    Device = object
+    DType = object
+
+
 _T_co = TypeVar("_T_co", covariant=True)
+
+# These "Just" types are equivalent to the `Just` type from the `optype` library,
+# apart from them not being `@runtime_checkable`.
+# - docs: https://github.com/jorenham/optype/blob/master/README.md#just
+# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py
+@final
+class JustInt(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[int]: ...
+    @__class__.setter
+    def __class__(self, value: type[int], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustFloat(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[float]: ...
+    @__class__.setter
+    def __class__(self, value: type[float], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
+@final
+class JustComplex(Protocol):  # type: ignore[misc]
+    @property  # type: ignore[override]
+    def __class__(self, /) -> type[complex]: ...
+    @__class__.setter
+    def __class__(self, value: type[complex], /) -> None: ...  # pyright: ignore[reportIncompatibleMethodOverride]
+
+
 class NestedSequence(Protocol[_T_co]):
     def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
     def __len__(self, /) -> int: ...


-SupportsBufferProtocol = Any
-Array = Any
-Device = Any
-DType = Any
-Namespace = ModuleType
+class SupportsArrayNamespace(Protocol[_T_co]):
+    def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ...
+
+
+class HasShape(Protocol[_T_co]):
+    @property
+    def shape(self, /) -> _T_co: ...
+
+
+# Return type of `__array_namespace_info__.capabilities`
+Capabilities = TypedDict(
+    "Capabilities",
+    {
+        "boolean indexing": bool,
+        "data-dependent shapes": bool,
+        "max dimensions": int,
+    },
+)
+
+# Return type of `__array_namespace_info__.default_dtypes`
+DefaultDTypes = TypedDict(
+    "DefaultDTypes",
+    {
+        "real floating": DType,
+        "complex floating": DType,
+        "integral": DType,
+        "indexing": DType,
+    },
+)
+
+
+_DTypeKind: TypeAlias = Literal[
+    "bool",
+    "signed integer",
+    "unsigned integer",
+    "integral",
+    "real floating",
+    "complex floating",
+    "numeric",
+]
+# Type of the `kind` parameter in `__array_namespace_info__.dtypes`
+DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...]
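The point of the `Just*` protocols: a parameter annotated `float` also admits `int` and `bool` to a static type checker, whereas `JustFloat` admits only objects whose `__class__` is exactly `float`. This is how `vector_norm`'s `ord: JustInt | JustFloat` annotation above rejects `True` while still accepting `2` and `2.0`. A sketch of the effect; the helper name is hypothetical and the import path is assumed:

```python
from array_api_compat.common._typing import JustFloat, JustInt  # assumed path

def set_ord(ord: JustInt | JustFloat) -> None:  # hypothetical helper
    ...

set_ord(2)     # accepted: __class__ is exactly int
set_ord(2.0)   # accepted: __class__ is exactly float
set_ord(True)  # runs fine at runtime, but mypy/pyright flag it: bool is not "just" int
```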
+ + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + + +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] + + +__all__ = [ + "Array", + "Capabilities", + "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", + "Device", + "HasShape", + "Namespace", + "JustInt", + "JustFloat", + "JustComplex", + "NestedSequence", + "SupportsArrayNamespace", + "SupportsBufferProtocol", +] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 59e01058..246ac872 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,12 +6,19 @@ # These imports may overwrite names from the import * above. 
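Because the `DTypes*` TypedDicts above compose by inheritance (e.g. `DTypesIntegral` is `DTypesSigned` plus `DTypesUnsigned`), each `kind` value maps to a precisely typed table. A sketch of a value satisfying `DTypesSigned`, with the import path assumed:

```python
import numpy as np
from array_api_compat.common._typing import DTypesSigned  # assumed path

signed: DTypesSigned = {
    "int8": np.dtype(np.int8),
    "int16": np.dtype(np.int16),
    "int32": np.dtype(np.int32),
    "int64": np.dtype(np.int64),
}
```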
from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') -__import__(__package__ + '.fft') +__array_api_version__: Final = '2025.12' -from ..common._helpers import * # noqa: F401,F403 +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2024.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..44808ec9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from builtins import bool as py_bool +from typing import Literal + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - -from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -56,33 +54,24 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) sign = get_xp(cp)(_aliases.sign) +finfo = get_xp(cp)(_aliases.finfo) +iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -90,35 +79,23 @@ def asarray( specification for more details. """ with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. - - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. 
- if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) + if not copy and res is not obj: + raise ValueError("Unable to avoid copy while creating an array as requested") + return res def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, -) -> ndarray: + copy: py_bool = True, + device: Device | None = None, +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) @@ -127,10 +104,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, - axis=None, - keepdims=False -) -> ndarray: + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: @@ -138,6 +115,66 @@ def count_nonzero( return cp.expand_dims(result, axis) return result +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + + +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + return cp.take_along_axis(x, indices, axis=axis) + + +# https://github.com/cupy/cupy/pull/9582 +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: + return tuple(cp.broadcast_arrays(*arrays)) + + +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: + return tuple(cp.meshgrid(*arrays, indexing=indexing)) + + +# Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum +# supported version +def searchsorted( + x1: Array, + x2: Array | int | float, + /, + *, + side: Literal['left', 'right'] = 'left', + sorter: Array | None = None +) -> Array: + if not isinstance(x2, cp.ndarray): + if not isinstance(x2, int | float | complex): + raise NotImplementedError( + 'Only python scalars or ndarrays are supported for x2') + x2 = cp.asarray(x2) + return cp.searchsorted(x1, x2, side, sorter) + + +# CuPy isin does not accept scalars +def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False, **kwds) -> Array: + if isinstance(x1, int): + x1 = cp.asarray(x1) + if isinstance(x2, int): + x2 = cp.asarray(x2) + return cp.isin(x1, x2, invert=invert, **kwds) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. 
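The rewritten CuPy `asarray` adopts the standard tri-state `copy` contract. NumPy 2 implements the same contract natively, so the expected behaviour can be sketched without a GPU (NumPy >= 2.0 assumed):

```python
import numpy as np  # NumPy >= 2.0 assumed

x = np.arange(3)
assert np.asarray(x) is x                 # copy=None: copy only if needed
assert np.asarray(x, copy=True) is not x  # copy=True: always copy
try:
    np.asarray(x, dtype=np.float64, copy=False)  # dtype change forces a copy
except ValueError as exc:
    print(exc)  # "Unable to avoid copy ..." — the same error the wrapper raises
```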
@@ -156,10 +193,16 @@ def count_nonzero( else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'ceil', 'floor', 'trunc', 'take_along_axis', + 'broadcast_arrays', 'meshgrid', + 'searchsorted', 'isin', +] + -_all_ignore = ['cp', 'get_xp'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..aef10e85 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,13 +95,13 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. 
See Also @@ -323,4 +333,4 @@ def devices(self): __array_namespace_info__.dtypes """ - return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] + return tuple(cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())) diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,30 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..53a9a454 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n @@ -30,7 +31,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..4e532f9f 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,13 +1,25 @@ from cupy.linalg import * # noqa: F403 -# cupy.linalg doesn't have __all__. If it is added, replace this with + +# https://github.com/cupy/cupy/issues/9749 +from cupy.linalg import lstsq # noqa: F401 + +# cupy.linalg doesn't have __all__ in cupy<14. 
If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] -linalg_all = list(_n) +linalg_all = list(_n) + ['lstsq'] del _n +try: + # cupy 14 exports it, cupy 13 does not + from cupy.linalg import annotations # noqa: F401 + linalg_all += ['annotations'] +except ImportError: + pass + + from ..common import _linalg from .._internal import get_xp @@ -43,7 +55,8 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +# cupy 13 does not have __all__, cupy 14 has it: remove duplicates +__all__ = sorted(set(__all__)) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a6e69ad3..d25ae513 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,9 +1,26 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from . import _aliases +from ._aliases import * # type: ignore[assignment] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 + +__array_api_version__: Final = "2025.12" +del Final + +# See the comment in the numpy __init__.py +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') -__array_api_version__ = '2024.12' +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..54d323b2 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,49 +1,46 @@ -from __future__ import annotations - -from typing import Callable +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false -from ...common import _aliases, array_namespace +from __future__ import annotations -from ..._internal import get_xp +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any -from ._info import __array_namespace_info__ +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # Dtypes - iinfo, - finfo, - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - can_cast, - result_type, ) -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - -import dask.array as da +from ..._internal import get_xp +from ...common import _aliases, _helpers, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,11 +49,11 @@ # da.astype doesn't respect copy=True def astype( x: 
Array, - dtype: Dtype, + dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -65,6 +62,7 @@ def astype( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if not copy and dtype == x.dtype: return x @@ -79,14 +77,14 @@ def astype( # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for arange(). @@ -95,8 +93,9 @@ def arange( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) - args = [start] + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -134,30 +133,22 @@ def arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -166,16 +157,17 @@ def asarray( specification for more details. """ # TODO: respect device keyword? 
+ _helpers._check_device(da, device) if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: - raise NotImplementedError( + raise ValueError( "Unable to avoid copy when converting a non-dask object to dask" ) @@ -185,22 +177,21 @@ def asarray( return da.from_array(obj) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift # dask.array.clip does not work unless all three arguments are provided. @@ -210,8 +201,8 @@ def asarray( def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: """ Array API compatibility wrapper for clip(). @@ -220,8 +211,8 @@ def clip( specification for more details. """ - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: + return a is None or isinstance(a, (int, float)) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -274,7 +265,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], def sort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. @@ -304,7 +300,12 @@ def sort( def argsort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. 
@@ -338,26 +339,31 @@ def argsort( # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | None = None, + keepdims: py_bool = False, ) -> Array: - result = da.count_nonzero(x, axis) - if keepdims: - if axis is None: - return da.reshape(result, [1]*x.ndim) - return da.expand_dims(result, axis) - return result - - - -__all__ = _aliases.__all__ + [ - '__array_namespace_info__', 'asarray', 'astype', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', - 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', - 'can_cast', 'count_nonzero', 'result_type'] - -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result + + +__all__ = [ + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index e15a69f4..3a7285d5 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,50 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal, TypeAlias, overload + +import dask.array as da +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) +Device: TypeAlias = Literal["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -50,7 +75,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -59,20 +84,31 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library - supports boolean indexing. Always ``False`` for Dask. + supports boolean indexing. + + Dask support boolean indexing as long as both the index + and the indexed arrays have known shapes. 
+ Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. - **"data-dependent shapes"**: boolean indicating whether an array - library supports data-dependent output shapes. Always ``False`` for - Dask. + library supports data-dependent output shapes. + + Dask implements unique_values et.al. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. + + - **"max dimensions"**: integer indicating the maximum number of + dimensions supported by the array library. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html @@ -92,20 +128,20 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { - "boolean indexing": False, - "data-dependent shapes": False, - # 'max rank' will be part of the 2024.12 standard + "boolean indexing": True, + "data-dependent shapes": True, "max dimensions": 64, } - def default_device(self): + def default_device(self) -> Device: """ The default device used for new Dask arrays. @@ -120,19 +156,19 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() 'cpu' """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -163,7 +199,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -171,11 +207,7 @@ def default_dtypes(self, *, device=None): 'indexing': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' - ) + _check_device(da, device) return { "real floating": dtype(float64), "complex floating": dtype(complex128), @@ -183,7 +215,41 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: None = None + ) -> DTypesAll: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["bool"] + ) -> DTypesBool: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["signed integer"] + ) -> DTypesSigned: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["unsigned integer"] + ) -> DTypesUnsigned: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["integral"] + ) -> DTypesIntegral: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["real floating"] + ) -> DTypesReal: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["complex floating"] + ) -> DTypesComplex: ... + @overload + def dtypes( + self, /, *, device: Device | None = None, kind: Literal["numeric"] + ) -> DTypesNumeric: ... 
+ def dtypes( + self, /, *, device: Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -229,7 +295,7 @@ Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -237,11 +303,7 @@ """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -311,13 +373,13 @@ "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> tuple[Device, ...]: """ The devices supported by Dask. @@ -325,7 +387,7 @@ Returns ------- - devices : list of str + devices : tuple[Device, ...] The devices supported by Dask. See Also -------- @@ -337,9 +399,9 @@ Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] """ - return ["cpu", _DASK_DEVICE] + return ("cpu", _DASK_DEVICE) diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,12 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__.
If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n = {} -exec('from dask.array.fft import *', _n) -del _n['__builtins__'] -fft_all = list(_n) -del _n +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -16,9 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] +__all__ += ["fftfreq", "rfftfreq"] -del get_xp -del da -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..a9be5d5f 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,20 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer +import dask.array as da -# These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, outer, tensordot -import dask.array as da +# Exports +from ..._internal import clone_module, get_xp +from ...common import _linalg +from ...common._typing import Array -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +__all__ = clone_module("dask.array.linalg", globals()) -# dask.array.linalg doesn't have __all__. If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] -linalg_all = list(_n) -del _n +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -37,8 +24,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( # type: ignore[no-redef] + x: Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -51,9 +41,9 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: - raise ValueError("full_matrics=True is not supported by dask.") + raise ValueError("full_matrices=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) def svdvals(x: Array) -> Array: @@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - 
"QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 02c55d28..973e993d 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,16 @@ -from numpy import * # noqa: F403 +from typing import Final -# from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from .._internal import clone_module + +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from . import _aliases +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,18 +19,19 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') +__import__(__spec__.parent + ".linalg") -__import__(__package__ + '.fft') +__import__(__spec__.parent + ".fft") -from .linalg import matrix_transpose, vecdot # noqa: F401 +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 -from ..common._helpers import * # noqa: F403 +__array_api_version__: Final = "2025.12" -try: - # Used in asarray(). Not present in older versions. 
- from numpy import _CopyMode # noqa: F401 -except ImportError: - pass +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2024.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..87b3c2f3 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,16 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from ..common import _aliases +from builtins import bool as py_bool +from typing import Any, cast -from .._internal import get_xp - -from ._info import __array_namespace_info__ +import numpy as np -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from .._internal import get_xp +from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType -import numpy as np bool = np.bool_ # Basic renames @@ -56,111 +55,137 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): - try: - memoryview(obj) - except TypeError: - return False - return True # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, - **kwargs, -) -> ndarray: + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: Any, +) -> Array: """ Array API compatibility wrapper for asarray(). See the corresponding documentation in the array library and/or the array API specification for more details. """ - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") - - if hasattr(np, '_CopyMode'): - if copy is None: - copy = np._CopyMode.IF_NEEDED - elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS - else: - # Not present in older NumPys. In this case, we cannot really support - # copy=False. 
- if copy is False: - raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") + _helpers._check_device(np, device) + + # copy=None is unsupported in NumPy 1.x, where we can use an internal enum instead; + # copy=False in NumPy 1.x means copy=None in NumPy 2.0 and in the Array API. + if copy is None: + copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] + elif copy is False: + copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, -) -> ndarray: + copy: py_bool = True, + device: Device | None = None, +) -> Array: + _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 def count_nonzero( - x : ndarray, - axis=None, - keepdims=False -) -> ndarray: - result = np.count_nonzero(x, axis=axis, keepdims=keepdims) + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + # NOTE: this is currently incorrectly typed in numpy, but will be fixed in + # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result +# take_along_axis: axis defaults to -1 but in numpy axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + return np.take_along_axis(x, indices, axis=axis) + + +# ceil, floor, and trunc return integers for integer inputs in NumPy < 2 + +def ceil(x: Array, /) -> Array: + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() + return np.ceil(x) + + +def floor(x: Array, /) -> Array: + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() + return np.floor(x) + + +def trunc(x: Array, /) -> Array: + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() + return np.trunc(x) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper.
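+# Feature detection happens once at import time: when NumPy itself already
+# provides a function, the native version is re-exported; otherwise the
+# generic wrapper from common/_aliases.py fills the gap.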
-if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = _aliases.__all__ + [ + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "ceil", + "floor", + "trunc", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", + "take_along_axis" +] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..9ba004da 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,29 @@ more details. """ +from __future__ import annotations + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ..common._typing import DefaultDTypes +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -94,13 +99,13 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -119,7 +124,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -131,7 +136,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. @@ -183,7 +192,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +274,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +326,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> tuple[Device]: """ The devices supported by NumPy. @@ -326,7 +340,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. 
See Also @@ -343,4 +357,11 @@ def devices(self): ['cpu'] """ - return ["cpu"] + return ("cpu",) + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,29 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] - -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) - -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) - -Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +from typing import TYPE_CHECKING, Any, Literal, TypeAlias + +import numpy as np + +Device: TypeAlias = Literal["cpu"] + +if TYPE_CHECKING: + + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + ] + Array: TypeAlias = np.ndarray[Any, DType] else: - Dtype = dtype + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,11 @@ -from numpy.fft import * # noqa: F403 -from numpy.fft import __all__ as fft_all +import numpy as np -from ..common import _fft -from .._internal import get_xp +from .._internal import clone_module -import numpy as np +__all__ = clone_module("numpy.fft", globals()) + +from .._internal import get_xp +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,9 +22,9 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ -del get_xp -del np -del fft_all -del _fft +__all__ = sorted(set(__all__) | set(_fft.__all__)) + +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..474efe50 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,18 +1,25 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false +from __future__ import annotations + +import numpy as np + +from .._internal import clone_module, get_xp from ..common import _linalg -from .._internal import get_xp -# These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +__all__ = clone_module("numpy.linalg", globals()) -import numpy as np +# These functions are in both the main and linalg namespaces +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) EighResult = _linalg.EighResult +EigResult = _linalg.EigResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult 
SVDResult = _linalg.SVDResult @@ -38,19 +45,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. + # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +77,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +85,124 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + +# Unlike numpy.linalg.eig, Array API version always returns complex results + +def eig(x: Array, /) -> tuple[Array, Array]: + try: + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + except ImportError: + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + from numpy.linalg import _umath_linalg + + x, wrap = _makearray(x) + _assert_stacked_square(x) + _assert_finite(x) + t, result_t = _commonType(x) + + signature = 'D->DD' if isComplexType(t) else 'd->DD' + with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w, vt = _umath_linalg.eig(x, signature=signature) + + result_t = _complexType(result_t) + vt = vt.astype(result_t, copy=False) + return EigResult(w.astype(result_t, copy=False), wrap(vt)) + + +def eigvals(x: Array, /) -> Array: + try: + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + 
_raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + except ImportError: + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + from numpy.linalg import _umath_linalg + + x, wrap = _makearray(x) + _assert_stacked_square(x) + _assert_finite(x) + t, result_t = _commonType(x) + + signature = 'D->D' if isComplexType(t) else 'd->D' + with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w = _umath_linalg.eigvals(x, signature=signature) + + result_t = _complexType(result_t) + return w.astype(result_t, copy=False) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +_all = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "solve", + "tensorinv", + "tensorsolve", + "vector_norm", +] +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index a985986e..c5c801aa 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,24 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(n + ' = torch.' + n) +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . 
import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') -__import__(__package__ + '.fft') +__array_api_version__: Final = '2025.12' -from ..common._helpers import * # noqa: F403 +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) -__array_api_version__ = '2024.12' +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index b4786320..e27c3ca2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,22 +1,17 @@ from __future__ import annotations -from functools import wraps as _wraps +from collections.abc import Sequence +from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common import _aliases -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Any, Literal +import math import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -41,47 +36,23 @@ torch.complex128, } -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, +_promotion_table = { # ints - (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats - (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, # complexes - (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, - (torch.complex128, torch.complex64): torch.complex128, - (torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, @@ -89,6 +60,9 @@ (torch.float64, torch.complex128): torch.complex128, } +_promotion_table.update({(b, a): c for 
(a, b), c in _promotion_table.items()}) +_promotion_table.update({(a, a): a for a in _array_api_dtypes}) + def _two_arg(f): @_wraps(f) @@ -123,26 +97,46 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: + num = len(arrays_and_dtypes) + + if num == 0: + raise ValueError("At least one array or dtype must be provided") + + elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - x, y = arrays_and_dtypes - if isinstance(x, _py_scalars) or isinstance(y, _py_scalars): - return torch.result_type(x, y) + if num == 2: + x, y = arrays_and_dtypes + return _result_type(x, y) + + else: + # sort scalars so that they are treated last + scalars, others = [], [] + for x in arrays_and_dtypes: + if isinstance(x, _py_scalars): + scalars.append(x) + else: + others.append(x) + if not others: + raise ValueError("At least one array or dtype must be provided") + + # combine left-to-right + return _reduce(_result_type, others + scalars) - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: + if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): + xdt = x if isinstance(x, torch.dtype) else x.dtype + ydt = y if isinstance(y, torch.dtype) else y.dtype + + try: + return _promotion_table[xdt, ydt] + except KeyError: + pass # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -151,7 +145,8 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: + +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -193,32 +188,73 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amin(x, axis, keepdims=keepdims) -clip = get_xp(torch)(_aliases.clip) unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) +finfo = get_xp(torch)(_aliases.finfo) +iinfo = get_xp(torch)(_aliases.iinfo) + # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +# Wrap torch.argsort to set stable=True by default +def argsort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: + return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) + + def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: @@ -242,7 +278,7 @@ def _axis_none_keepdims(x, ndim, keepdims): # (https://github.com/pytorch/pytorch/issues/71209) # Note that this is only valid for the axis=None case. if keepdims: - for i in range(ndim): + for _ in range(ndim): x = torch.unsqueeze(x, 0) return x @@ -261,28 +297,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, + +def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: + """ + Implements `sum(..., axis=())` and `prod(..., axis=())`. + + Works around https://github.com/pytorch/pytorch/issues/29137 + """ + if dtype is not None: + return x.clone() if dtype == x.dtype else x.to(dtype) + + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what prod does + # when axis=None. + if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): + return x.to(torch.int64) + + return x.clone() + + +def prod(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: - # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic - # below because it still needs to upcast. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) # torch.prod doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586).
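+    # A tuple of axes is therefore reduced via the _reduce_multiple_axes helper above.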
if isinstance(axis, tuple): @@ -291,51 +334,38 @@ def prod(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: + if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes @@ -347,20 +377,19 @@ def any(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> array: - x = torch.asarray(x) - ndim = x.ndim + **kwargs: object) -> Array: + if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes @@ -372,18 +401,18 @@ def all(x: array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -395,13 +424,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -426,13 +455,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -455,11 +484,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: tuple[Array, ...] | list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -468,7 +497,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -482,27 +511,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] 
| None = None, **kwargs: object) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -510,45 +539,59 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Array | None = None, + append: Array | None = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: - if axis is not None: + if isinstance(axis, int): return result.unsqueeze(axis) + elif isinstance(axis, tuple): + n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis] + sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)] + return torch.reshape(result, sh) return _axis_none_keepdims(result, x.ndim, keepdims) else: return result +# "repeat" is torch.repeat_interleave; also the dim argument +def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + if isinstance(repeats, torch.Tensor) and repeats.dtype in (torch.int8, torch.int16): + # torch rejects short integers for the `repeat` argument: + # https://github.com/pytorch/pytorch/issues/151311 + repeats = repeats.to(torch.int32) + return torch.repeat_interleave(x, repeats, axis) -def where(condition: array, x1: array, x2: array, /) -> array: + +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) + # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, - shape: Tuple[int, ...], - copy: Optional[bool] = None, - **kwargs) -> array: + shape: tuple[int, ...], + *, + copy: bool | None = None, + **kwargs: object) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -557,14 +600,14 @@ def reshape(x: array, # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], +def arange(start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -573,19 +616,23 @@ def arange(start: Union[int, float], dtype = torch.int64 else: dtype = torch.float32 - return torch.empty(0, dtype=dtype, device=device, **kwargs) - return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + return torch.empty(0, device=device, **kwargs).to(dtype) + try: + return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + # 
torch 2.7 raises RuntimeError, 2.9 emits NotImplementedError + except (NotImplementedError, RuntimeError): + return torch.arange(start, stop, step, device=device, **kwargs).to(dtype) # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -594,83 +641,98 @@ def eye(n_rows: int, return z # torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], +def linspace(start: float, + stop: float, /, num: int, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[Dtype] = None, - device: Optional[Device] = None, - **kwargs) -> array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: - return torch.unsqueeze(x, axis) +def expand_dims(x: Array, /, axis: int | tuple[int, ...]) -> Array: + if isinstance(axis, int): + return 
torch.unsqueeze(x, axis) + else: + # follow https://github.com/numpy/numpy/blob/maintenance/2.4.x/numpy/lib/_shape_base_impl.py#L596-L602 + y_ndim = x.ndim + len(axis) + # normalize + n_axis = tuple(ax + y_ndim if ax < 0 else ax for ax in axis) + if (len(n_axis) != len(set(n_axis)) or + _builtin_any(ax < 0 or ax >= y_ndim for ax in n_axis) + ): + raise ValueError(f"{axis=} not allowed for {x.shape = }") + + shape_it = iter(x.shape) + shape = [1 if ax in n_axis else next(shape_it) for ax in range(y_ndim)] + + return torch.reshape(x, shape) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, - device: Optional[Device] = None, -) -> array: + device: Device | None = None, +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) - return [torch.broadcast_to(a, shape) for a in arrays] + return tuple(torch.broadcast_to(a, shape) for a in arrays) # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. @@ -678,7 +740,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -691,7 +753,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. 
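# A minimal sketch of why unique_counts patches the counts for NaN entries
# (assumes a torch version where torch.unique reports a zero count for NaN;
# NaN != NaN, so each NaN comes back as its own unique value whose true
# count is 1):
import math
import torch

values, counts = torch.unique(torch.tensor([math.nan, 1.0]), return_counts=True)
counts[torch.isnan(values)] = 1  # restore the correct multiplicity of each NaN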
@@ -699,14 +761,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -714,12 +776,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -727,8 +796,10 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. 
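# A usage sketch for isdtype as wrapped here (array_api_compat.torch is the
# package this file belongs to; the dtype objects come from the cloned torch
# namespace):
import array_api_compat.torch as xp

assert xp.isdtype(xp.float32, "real floating")
assert xp.isdtype(xp.int64, ("integral", "bool"))  # a tuple means "any of these kinds"
assert not xp.isdtype(xp.complex64, "real floating")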
@@ -762,19 +833,75 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 - return torch.index_select(x, axis, indices, **kwargs) + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 + return torch.index_select( + x, + axis, + torch.where(indices < 0, indices + x.shape[axis], indices), + **kwargs + ) + + +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 + return torch.take_along_dim( + x, + torch.where(indices < 0, indices + x.shape[axis], indices), + dim=axis + ) + + +def clip( + x: Array, + /, + min: int | float | Array | None = None, + max: int | float | Array | None = None, + **kwargs +) -> Array: + def _isscalar(a: object): + return isinstance(a, int | float) or a is None + + # cf clip in common/_aliases.py + if not x.is_floating_point(): + if type(min) is int and min <= torch.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= torch.iinfo(x.dtype).max: + max = None + + if min is None and max is None: + return torch.clone(x) + + min_is_scalar = _isscalar(min) + max_is_scalar = _isscalar(max) + + if min_is_scalar and max_is_scalar: + if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)): + # edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.) + # https://github.com/pytorch/pytorch/issues/172067 + return torch.full_like(x, fill_value=torch.nan) + return torch.clamp(x, min, max, **kwargs) + + # pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures, + # but does not accept (tensor, scalar, tensor) + a_min = min + if min is not None and min_is_scalar: + a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device) + a_max = max + if max is not None and max_is_scalar: + a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: - return torch.take_along_dim(x, indices, dim=axis) + return torch.clamp(x, a_min, a_max, **kwargs) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: @@ -789,22 +916,45 @@ def sign(x: array, /) -> array: return out -__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', +def round(x: Array, /, **kwargs) -> Array: + # torch.round fails for complex inputs + # https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845 + if x.dtype.is_complex: + out = kwargs.pop('out', None) + res_r = torch.round(x.real, **kwargs) + res_i = torch.round(x.imag, **kwargs) + res = res_r + 1j*res_i + if out is not None: + out.copy_(res) + return out + return res + else: + return torch.round(x, **kwargs) + + +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: + # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it + # will be required to pass the indexing argument." + # Thus always pass it explicitly. 
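+    # Validate eagerly and raise ValueError ourselves rather than deferring to
+    # torch's own error for unknown values.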
+ if indexing not in ("xy", "ij"): + raise ValueError(f'torch.meshgrid: indexing must be one of "xy" or "ij", but received: {indexing}') + return torch.meshgrid(*arrays, indexing=indexing) if arrays else () + + +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', - 'diff', 'divide', + 'diff', 'divide', 'round', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', - 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum', - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', + 'argsort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', + 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign'] - -_all_ignore = ['torch', 'get_xp'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..050c7846 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -34,7 +34,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, @@ -76,16 +76,16 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. 
Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" @@ -355,4 +366,4 @@ def devices(self): break i += 1 - return devices + return tuple(devices) diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..52670871 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,3 @@ +__all__ = ["Array", "Device", "DType"] + +from torch import device as Device, dtype as DType, Tensor as Array diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..0fa6ea9a 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,86 +1,82 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal -from torch.fft import * # noqa: F403 +import torch # noqa: F401 import torch.fft +from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) + # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, - s: 
Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, -) -> array: + **kwargs: object, +) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( - x: array, + x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, -) -> array: + axes: int | Sequence[int] | None = None, + **kwargs: object, +) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( - x: array, + x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, -) -> array: + axes: int | Sequence[int] | None = None, + **kwargs: object, +) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -] +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] -_all_ignore = ['torch'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e26198b9..08271d22 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,42 +1,35 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional, Union, Tuple, Literal - inf = float('inf') +import torch +import torch.linalg -from ._aliases import _fix_promotion, sum - -from torch.linalg import * # noqa: F403 +from .._internal import clone_module -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +__all__ = clone_module("torch.linalg", globals()) # outer is implemented in torch but aren't in the linalg namespace from torch import outer +from ._aliases import _fix_promotion, sum # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot +from ._typing import Array, DType +from ..common._typing import JustInt, JustFloat # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) - return torch_linalg.cross(x1, x2, dim=axis) + return torch.linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = 
-1, **kwargs) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -58,7 +51,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: array, x2: array, /, **kwargs) -> array: +def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +72,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, - **kwargs, -) -> array: + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, + **kwargs: object, +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') @@ -113,9 +107,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] - -_all_ignore = ['torch_linalg', 'sum'] +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] -del linalg_all +def __dir__() -> list[str]: + return __all__ diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 63e844cd..a32e382c 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -11,12 +11,10 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem -# copy=False is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - -# finfo test is testing that the result is a float instead of float32 (see -# also https://github.com/data-apis/array-api/issues/405) +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -26,6 +24,10 @@ array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] array_api_tests/test_linalg.py::test_solve +# 2025.12 support; {eig,eigvals} are new in CuPy 14 +array_api_tests/test_linalg.py::test_eig +array_api_tests/test_linalg.py::test_eigvals + # We cannot modify array methods array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] 
array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)]
@@ -36,6 +38,20 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
# floating point inaccuracy
array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)]
+# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1)
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
+
# cupy (arg)min/max wrong with infinities
# https://github.com/cupy/cupy/issues/7424
@@ -168,6 +185,18 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0]
+# complex spec cases
+array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2]
+array_api_tests/test_special_cases.py::test_unary[log(real(x_i) is -0 and imag(x_i) is +0) -> -infinity + \u03c0j]
+array_api_tests/test_special_cases.py::test_unary[log((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[log(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[log1p((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[log1p(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j]
+array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is +infinity) -> +0 + NaN j]
+array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is NaN) -> +0 + NaN j]
+array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j]
+
# CuPy gives the wrong shape for n-dim fft funcs.
See # https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 array_api_tests/test_fft.py::test_fftn @@ -183,12 +212,13 @@ array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -+# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[count_nonzero] -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +# 2024.12 support array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] +# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars + +# CuPy 14 does not support copy= dlpack argument +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack + diff --git a/dask-xfails.txt b/dask-xfails.txt index d2474f9f..34d3afd6 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -12,8 +12,10 @@ array_api_tests/test_array_object.py::test_getitem_masking # zero division error, and typeerror: tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) @@ -22,18 +24,23 @@ array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take +# missing `take_along_axis`, https://github.com/dask/dask/issues/3663 +array_api_tests/test_indexing_functions.py::test_take_along_axis + # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# Fails because shape is NaN since we don't materialize it yet +# Data-dependent output shape +# These tests fail as array-api-tests doesn't cope with unknown shapes +# Also, output shape is (math.nan, ) instead of (None, ) +# Also, da.unique() doesn't accept equals_nan which causes non-compliant +# output when there are NaNs in the input. 
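+# (Editorial illustration: for a dask array x, da.unique(x).shape evaluates to
+# (nan,) until .compute() is called, so the set-function tests below cannot
+# validate result shapes up front.)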
array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts - -# Different error but same cause as above, we're just trying to do ndindex on nan shape array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values @@ -122,6 +129,19 @@ array_api_tests/test_linalg.py::test_matrix_norm array_api_tests/test_linalg.py::test_qr array_api_tests/test_manipulation_functions.py::test_roll +# 2025.12 support +array_api_tests/test_has_names.py::test_has_names[manipulation-broadcast_shapes] +array_api_tests/test_signatures.py::test_func_signature[broadcast_shapes] +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast_shapes +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error +array_api_tests/test_searching_functions.py::test_searchsorted_with_scalars + +array_api_tests/test_linalg.py::test_eig +array_api_tests/test_linalg.py::test_eigvals + +array_api_tests/test_manipulation_functions.py::TestExpandDims::test_expand_dims_tuples + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] @@ -141,3 +161,20 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] + +# no dlpack support +array_api_tests/test_dlpack.py::test_dlpack_device +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack diff --git a/docs/changelog.md b/docs/changelog.md index 1de11606..55449df0 
100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,6 +1,145 @@
 # Changelog

-## 1.11.0 (2025-XX-XX)
+## 1.14.0 (2026-02-26)
+
+### Major changes
+
+This release targets the 2025.12 Array API revision. This includes:
+
+ - `__array_api_version__` for the wrapped APIs is now set to `2025.12`;
+ - wrappers for `linalg.eig` and `linalg.eigvals`;
+ - wrappers for `isin` and `searchsorted` that accept Python scalars;
+ - wrappers for `expand_dims` that accept tuple axes;
+ - `broadcast_arrays`, `meshgrid` and `__array_namespace_info__().devices()` have been
+   changed to return tuples, not lists.
+
+Additionally,
+
+ - `clip` wrappers have been fixed to be compatible with `torch.vmap`.
+
+
+### Minor changes
+
+ - `expand_dims` wrappers have been fixed to accept the `axis` argument as either a keyword
+   or a positional argument;
+ - the `torch.clip` wrapper has been fixed to correctly handle `nan` scalars;
+ - the `torch.repeat` wrapper has been fixed so it no longer errors out for short integers.
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Josh Soref.
+
+
+## 1.13.0 (2025-12-28)
+
+
+### Major changes
+
+- Support for Python 3.14 has been added.
+- Symbols exported in public namespaces have been reviewed and adjusted.
+- `torch.take` and `torch.take_along_axis` now support negative indices.
+- `torch.meshgrid` now correctly processes the `indexing` argument.
+- View/copy semantics are now observed for the `ceil`, `floor`, and `trunc` functions.
+
+### Minor changes
+
+- `array_namespace` has been sped up via caching.
+- The `stable` parameter of `torch.argsort` now defaults to `True`, per the standard.
+- Type annotations have been improved.
+- `is_jax_array` has been adjusted for compatibility with `jax>=0.8.2`.
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Guido Imperiale,
+Lucas Colley,
+Arthur Lacote,
+Martin Schuck,
+Matt Haberland.
+
+
+## 1.12.0 (2025-05-13)
+
+
+### Major changes
+
+- The build system has been updated to use `pyproject.toml` instead of `setup.py`.
+- Support for Python 3.9 has been dropped. The minimum supported Python version is now
+  3.10; the minimum supported NumPy version is 1.22.
+- The `linalg` extension works correctly with `pytorch>=2.7`.
+- Multiple improvements to handling of devices in CuPy and PyTorch backends.
+  Support for multiple devices in CuPy is still immature and you should use
+  context managers rather than relying on input-output device propagation or
+  on the `device` parameter. Please report any issues you encounter.
+
+### Minor changes
+
+- `finfo` and `iinfo` functions now accept array arguments, in accordance with the
+  Array API spec;
+- `torch.asarray` function propagates the device of the input array. This works around
+  the [pytorch issue #150199](https://github.com/pytorch/pytorch/issues/150199);
+- `torch.repeat` function is now available;
+- `torch.count_nonzero` function now correctly handles the case of a tuple `axis`
+  argument combined with `keepdims=True`;
+- `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the
+  array API specification;
+- `cupy.asarray` function now implements the `copy=False` argument, albeit
+  at the risk of making a temporary copy.
+- In `numpy.take_along_axis` and `cupy.take_along_axis` the `axis` parameter now
+  defaults to -1, in accordance with the Array API spec.
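As a minimal sketch of the last point (an editorial example assuming the wrapped NumPy namespace, not text from the original changelog):

```python
import array_api_compat.numpy as xp

x = xp.asarray([[30, 10, 20], [60, 40, 50]])
idx = xp.argsort(x)  # sort order along the last axis
# `axis` now defaults to -1, so this is equivalent to
# xp.take_along_axis(x, idx, axis=-1):
xp.take_along_axis(x, idx)
```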
+
+
+The following users contributed to this release:
+
+Evgeni Burovski,
+Lucas Colley,
+Neil Girdhar,
+Joren Hammudoglu,
+Guido Imperiale
+
+
+## 1.11.2 (2025-03-20)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+- fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple
+  issues with scalar arguments.
+- fix several issues with `clip` wrappers. Previously, `clip` was failing to allow
+  behaviors which are unspecified by the 2024.12 standard but allowed by the array
+  libraries.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+Magnus Dalen Kvalevåg
+
+
+## 1.11.1 (2025-03-04)
+
+This is a bugfix release with no new features compared to version 1.11.
+
+### Major Changes
+
+- fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in
+  several array libraries (torch, dask, cupy); work around numpy returning python
+  ints for some input combinations.
+
+### Minor Changes
+
+- running the self-tests does not require all array libraries. Missing libraries are
+  skipped.
+
+The following users contributed to this release:
+
+Evgeni Burovski
+Guido Imperiale
+
+
+## 1.11.0 (2025-02-27)

### Major Changes
@@ -114,11 +253,11 @@
 `xp.__array_namespace_info__()`.

- Various fixes to the `clip()` wrappers.

-- `torch.conj` now wrapps `torch.conj_physical`, which makes a copy rather
+- `torch.conj` now wraps `torch.conj_physical`, which makes a copy rather
  than setting the conjugation bit, as arrays with the conjugation bit set
  do not support some APIs.

-- `torch.sign` is now wrapped to support complex numbers and propogate nans
+- `torch.sign` is now wrapped to support complex numbers and propagate nans
  properly.

### Minor Changes
@@ -191,7 +330,7 @@ Thomas Li

- New flag `use_compat` to {func}`~.array_namespace` to force the use or
  non-use of the compat wrapper namespace. The default is to return a compat
- namespace when it is appropiate.
+ namespace when it is appropriate.

- Fix the `copy` flag to `asarray` for NumPy, CuPy, and Dask.
diff --git a/docs/dev/releasing.md b/docs/dev/releasing.md
index 1ee17709..21d3c36a 100644
--- a/docs/dev/releasing.md
+++ b/docs/dev/releasing.md
@@ -24,7 +24,7 @@
  This does mean you can ignore CI failures, but ideally you should fix any
  failures or update the `*-xfails.txt` files before tagging, so that CI and
- the CuPy tests fully pass. Otherwise it will be hard to tell what things are
+ the CuPy tests fully pass. Otherwise, it will be hard to tell what things are
  breaking in the future. It's also a good idea to remove any xpasses from
  those files (but be aware that some xfails are from flaky failures, so
  unless you know the underlying issue has been fixed, an xpass test is
@@ -100,6 +100,11 @@
  docs update (the docs are published automatically from the sources on
  `main`).

+- [ ] **Bump the `__version__` attribute in `__init__.py`**
+
+  After an M.N.0 release, further development is towards version `M.(N+1).0`; thus the main branch's
+  version is `M.(N+1).0.dev0`.
+
- [ ] **Update conda-forge.**

  After the PyPI package is published, the conda-forge bot should update the
diff --git a/docs/dev/tests.md b/docs/dev/tests.md
index 6d9d1d7b..18fb7cf5 100644
--- a/docs/dev/tests.md
+++ b/docs/dev/tests.md
@@ -7,7 +7,7 @@ the array API standard.
There are also array-api-compat specific tests in `tests/`. These tests should be limited
to things that are not tested by the test suite, e.g., tests for
[helper functions](../helper-functions.rst) or for behavior that is not
strictly required by the standard. To run these tests, install the
-dependencies from `requirements-dev.txt` (array-api-compat has [no hard
+dependencies from the `dev` optional group (array-api-compat has [no hard
runtime dependencies](no-dependencies)).

array-api-tests is run against all supported libraries on CI
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index dbec7740..00000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-furo
-linkify-it-py
-myst-parser
-sphinx
-sphinx-copybutton
-sphinx-autobuild
diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md
index 4519c4ac..46fcdc27 100644
--- a/docs/supported-array-libraries.md
+++ b/docs/supported-array-libraries.md
@@ -36,23 +36,16 @@ deviations from the standard should be noted:
 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and
 https://github.com/numpy/numpy/issues/22341)

-- `asarray()` does not support `copy=False`.
-
- Functions which are not wrapped may not have the same type annotations as
 the spec.

- Functions which are not wrapped may not use positional-only arguments.

-The minimum supported NumPy version is 1.21. However, this older version of
+The minimum supported NumPy version is 1.22. However, this older version of
NumPy has a few issues:

- `unique_*` will not compare nans as unequal.
-- `finfo()` has no `smallest_normal`.
- No `from_dlpack` or `__dlpack__`.
-- `argmax()` and `argmin()` do not have `keepdims`.
-- `qr()` doesn't support matrix stacks.
-- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not
-  supported even in the latest NumPy).
- Type promotion behavior will be value based for 0-D arrays (and there is no
 `NPY_PROMOTION_STATE=weak` to disable this).
@@ -72,8 +65,8 @@ version.
 attribute in the spec. Use the {func}`~.size()` helper function as a
 portable workaround.

-- PyTorch does not have unsigned integer types other than `uint8`, and no
-  attempt is made to implement them here.
+- PyTorch has incomplete support for unsigned integer types other
+  than `uint8`, and no attempt is made to implement them here.

- PyTorch has type promotion semantics that differ from the array API
  specification for 0-D tensor objects. The array functions in this wrapper
@@ -100,8 +93,6 @@ version.
- As with NumPy, type annotations and positional-only arguments may not
  exactly match the spec for functions that are not wrapped at all.

-The minimum supported PyTorch version is 1.13.

(jax-support)=
## [JAX](https://jax.readthedocs.io/en/latest/)
@@ -131,8 +122,6 @@ For `linalg`, several methods are missing, for example:
- `matrix_rank`
Other methods may only be partially implemented or return incorrect results at times.

-The minimum supported Dask version is 2023.12.0.
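To illustrate the `size()` workaround mentioned above (a hedged editorial sketch; `size` is the top-level array-api-compat helper):

```python
import torch
from array_api_compat import size

x = torch.ones((2, 3))
# torch.Tensor.size is a method rather than the attribute the spec requires,
# so the helper is the portable way to get the element count:
assert size(x) == 6
```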
- (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) diff --git a/numpy-1-21-xfails.txt b/numpy-1-22-xfails.txt similarity index 67% rename from numpy-1-21-xfails.txt rename to numpy-1-22-xfails.txt index 28c0e13a..20477f99 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -1,8 +1,7 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -37,38 +36,24 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# NumPy 1.21 specific XFAILS +# NumPy 1.22 specific XFAILS ############################ -# finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo[float64] - -# dlpack stuff -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - -# qr() doesn't support matrix stacks -array_api_tests/test_linalg.py::test_qr - # cross has some promotion bug that is fixed in newer numpy versions array_api_tests/test_linalg.py::test_cross +# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0. +# Fixed in newer numpy versions. 
+array_api_tests/test_creation_functions.py::test_linspace + # vector_norm with ord=-1 which has since been fixed # https://github.com/numpy/numpy/issues/21083 array_api_tests/test_linalg.py::test_vector_norm -# argmax and argmin do not support keepdims -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_signatures.py::test_func_signature[argmax] -array_api_tests/test_signatures.py::test_func_signature[argmin] - -# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with +# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues +# NOTE: some of these may not fail until one runs array-api-tests with +# --max-examples 100000 array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] @@ -107,6 +92,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -134,57 +120,29 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isu array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] 
-array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow]
+array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2]
+
+array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
+
# 2023.12 support
+array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-# uint64 repeats not supported
-array_api_tests/test_manipulation_functions.py::test_repeat

# 2024.12 support
array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
@@ -192,6 +150,14 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
+
+# 2025.12 support
+
+# older numpies return lists not tuples
+array_api_tests/test_creation_functions.py::test_meshgrid
+array_api_tests/test_data_type_functions.py::test_broadcast_arrays
+
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
@@ -212,3 +178,21 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+
+# complex special cases
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j]
+
+array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j]
+array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2]
+array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j]
+
+# no dlpack support
+array_api_tests/test_dlpack.py::test_dlpack_device
+array_api_tests/test_dlpack.py::test_dunder_dlpack
+array_api_tests/test_dlpack.py::test_from_dlpack
diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt
index 80790534..45d62f23 100644
--- a/numpy-1-26-xfails.txt
+++ b/numpy-1-26-xfails.txt
@@ -1,5 +1,7 @@
-# finfo(float32).eps returns float32 but should return float
+# attributes are np.float32 instead of float
+# (see also https://github.com/data-apis/array-api/issues/405)
array_api_tests/test_data_type_functions.py::test_finfo[float32]
+array_api_tests/test_data_type_functions.py::test_finfo[complex64]

# Array methods and attributes not already on np.ndarray cannot be wrapped
array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__]
@@ -37,15 +39,19 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices
# 2023.12 support
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
-# uint64 repeats not supported
-array_api_tests/test_manipulation_functions.py::test_repeat

# 2024.12 support
-array_api_tests/test_signatures.py::test_func_signature[bitwise_and]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars
+array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars
+
+# 2025.12 support
+
+# older numpies return lists not tuples
+array_api_tests/test_creation_functions.py::test_meshgrid
+array_api_tests/test_data_type_functions.py::test_broadcast_arrays
+
+# observed with numpy==1.26 only, looks like it is fixed on numpy 2.x
+array_api_tests/test_set_functions.py::TestIsin::test_isin_scalars

# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
@@ -66,3 +72,22 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + +# no dlpack support +array_api_tests/test_dlpack.py::test_dlpack_device +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack + + diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 98659710..45d3338f 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,22 +1,13 @@ -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] -# 2023.12 support -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat - -# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] - # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] @@ -36,3 +27,15 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] 
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + diff --git a/numpy-skips.txt b/numpy-skips.txt index cbf7235b..e69de29b 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0885dcaa..d8707c39 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,23 +1,14 @@ -# finfo(float32).eps returns float32 but should return float +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] -# 2023.12 support -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat - -# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] 
-array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_or]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
-array_api_tests/test_signatures.py::test_func_signature[bitwise_xor]
+# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
@@ -37,3 +28,16 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0]
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0]
+# complex special cases
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j]
+array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j]
+array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j]
+
+array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2]
+array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j]
+array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j]
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..d7339170
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,131 @@
+[build-system]
+requires = ["setuptools", "setuptools-scm"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "array-api-compat"
+dynamic = ["version"]
+description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard"
+readme = "README.md"
+requires-python = ">=3.10"
+license = "MIT"
+authors = [{name = "Consortium for Python Data API Standards"}]
+classifiers = [
+    "Operating System :: OS Independent",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
+    "Topic :: Software
Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[project.optional-dependencies] +cupy = ["cupy"] +dask = ["dask>=2024.9.0"] +jax = ["jax"] +# Note: array-api-compat follows scikit-learn minimum dependencies, which support +# much older versions of NumPy than what SPEC0 recommends. +numpy = ["numpy>=1.22"] +pytorch = ["torch"] +sparse = ["sparse>=0.15.1"] +ndonnx = ["ndonnx"] +docs = [ + "furo", + "linkify-it-py", + "myst-parser", + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", +] +dev = [ + "array-api-strict", + "dask[array]>=2024.9.0", + "jax[cpu]", + "ndonnx", + "numpy>=1.22", + "pytest", + "torch", + "sparse>=0.15.1", +] + +[project.urls] +homepage = "https://data-apis.org/array-api-compat/" +repository = "https://github.com/data-apis/array-api-compat/" + +[tool.setuptools.dynamic] +version = {attr = "array_api_compat.__version__"} + +[tool.setuptools.packages.find] +include = ["array_api_compat*"] +namespaces = false + +[tool.ruff.lint] +preview = true +select = [ + # Defaults + "E4", "E7", "E9", "F", + # Additional rules + "B", "C4", "ISC", "PIE", "FLY", "PERF", "UP", "FURB", + # Useless import alias + "PLC0414", + # Unused `noqa` directive + "RUF100", +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722", + # Use of `functools.cache` on methods can lead to memory leaks + "B019", + # No explicit `stacklevel` keyword argument found + "B028", + # Within an `except` clause, raise exceptions with `raise ... from ...` + "B904", + # `try`-`except` within a loop incurs performance overhead + "PERF203", +] + + +[tool.mypy] +files = ["array_api_compat"] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = false +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, +] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index c9d10f71..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,8 +0,0 @@ -array-api-strict -dask[array] -jax[cpu] -numpy -pytest -torch -sparse >=0.15.1 -ndonnx diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 72e111b5..00000000 --- a/ruff.toml +++ /dev/null @@ -1,17 +0,0 @@ -[lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] - -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" -] diff --git a/setup.py b/setup.py deleted file mode 100644 index 3d2b68a2..00000000 --- a/setup.py +++ /dev/null @@ -1,37 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r") as fh: - long_description = fh.read() - -import array_api_compat - -setup( - name='array_api_compat', - 
version=array_api_compat.__version__, - packages=find_packages(include=["array_api_compat*"]), - author="Consortium for Python Data API Standards", - description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-compat/", - license="MIT", - extras_require={ - "numpy": "numpy", - "cupy": "cupy", - "jax": "jax", - "pytorch": "pytorch", - "dask": "dask", - "sparse": "sparse >=0.15.1", - }, - python_requires=">=3.9", - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ] -) diff --git a/test_cupy.sh b/test_cupy.sh index a6974333..8ac72dc6 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,5 +26,5 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -export ARRAY_API_TESTS_VERSION=2024.12 +export ARRAY_API_TESTS_VERSION=2025.12 pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" diff --git a/tests/test_all.py b/tests/test_all.py index 10a2a95d..d9350ce7 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,49 +1,317 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. 
-""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "broadcast_shapes", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "isin", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "eig", + "eigvals", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis", "broadcast_shapes"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "eig", + "eigvals", + "matrix_power", + "pinv", + "slogdet", + ], +} + + +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return set(objs) + + +def get_mod(library, module, *, 
compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + aapi_names = set(NAMES[module]) + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names + # Quietly ignore *Result dataclasses + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" + ) -import pytest -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == "common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - for mod_name in sys.modules: - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. 
- if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 605c69a1..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,70 +2,72 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) @pytest.mark.parametrize("library", all_libraries) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - if library == "ndonnx" and api_version in ("2021.12", "2022.12"): - pytest.skip("Unsupported API version") + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") - namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: 
- if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace - - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace + + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -73,13 +75,16 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -88,38 +93,31 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) - -def test_api_version_torch(): - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2023.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2022.12") == torch_ - assert len(w) == 1 - assert "2022.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) + +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = pytest.importorskip(library) + a = xp.asarray(1) + b = 
xp.asarray(2) + assert array_namespace(a, b) is array_namespace(a) + + +def test_array_namespace_mismatch(): + xp = pytest.importorskip("array_api_strict") + with pytest.raises(TypeError, match="Multiple namespaces"): + array_namespace(np.asarray(1), xp.asarray(1)) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_namespace -def test_python_scalars(): - a = torch.asarray([1, 2]) - xp = import_("torch", wrapper=True) + +@pytest.mark.parametrize("library", all_libraries) +def test_python_scalars(library): + xp = pytest.importorskip(library) + a = xp.asarray([1, 2]) + xp = array_namespace(a) pytest.raises(TypeError, lambda: array_namespace(1)) pytest.raises(TypeError, lambda: array_namespace(1.0)) @@ -127,8 +125,8 @@ def test_python_scalars(): pytest.raises(TypeError, lambda: array_namespace(True)) pytest.raises(TypeError, lambda: array_namespace(None)) - assert array_namespace(a, 1) == xp - assert array_namespace(a, 1.0) == xp - assert array_namespace(a, 1j) == xp - assert array_namespace(a, True) == xp - assert array_namespace(a, None) == xp + assert array_namespace(a, 1) is xp + assert array_namespace(a, 1.0) is xp + assert array_namespace(a, 1j) is xp + assert array_namespace(a, True) is xp + assert array_namespace(a, None) is xp diff --git a/tests/test_common.py b/tests/test_common.py index 32876e69..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -17,6 +17,7 @@ from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) +from array_api_compat.common._helpers import _DASK_DEVICE from ._helpers import all_libraries, import_, wrapped_libraries, xfail @@ -189,23 +190,29 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library, request): +def test_device_to_device(library, request): if library == "ndonnx": - xfail(request, reason="Needs ndonnx >=0.9.4") + xfail(request, reason="Stub raises ValueError") + if library == "sparse": + xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") xp = import_(library, wrapper=True) + devices = xp.__array_namespace_info__().devices() - # We can't test much for device() and to_device() other than that - # x.to_device(x.device) works. 
- + # Default device x = xp.asarray([1, 2, 3]) dev = device(x) - x2 = to_device(x, dev) - assert device(x2) == device(x) - - x3 = xp.asarray(x, device=dev) - assert device(x3) == device(x) + for dev in devices: + if dev is None: # JAX >=0.5.3 + continue + if dev is _DASK_DEVICE: # TODO this needs a better design + continue + y = to_device(x, dev) + assert device(y) == dev @pytest.mark.parametrize("library", wrapped_libraries) @@ -234,6 +241,7 @@ def test_asarray_cross_library(source_library, target_library, request): # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved xfail(request, reason="Bug in dask raising error on conversion") + elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") @@ -241,6 +249,9 @@ def test_asarray_cross_library(source_library, target_library, request): xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": xfail(request, reason="produces numpy array of ndonnx scalar arrays") + elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"): + xfail(request, reason="unable to infer dtype") + elif source_library == "jax.numpy" and target_library == "torch": xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": @@ -260,7 +271,6 @@ def test_asarray_cross_library(source_library, target_library, request): assert b.dtype == tgt_lib.int32 - @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -270,100 +280,99 @@ def test_asarray_copy(library): xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] - all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'cupy': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'dask.array': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = True - else: - supports_copy_false_other_ns = True - supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false_same_ns: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false_same_ns: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype 
== xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error - if supports_copy_false_other_ns: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 - a = array.array('f', [1.0]) - if supports_copy_false_other_ns: + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library in ('cupy', 'dask.array'): + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. # https://github.com/dask/dask/pull/11524/ - assert all(b[0] == 1.0) + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 + + +@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) +def test_clip_out(library): + """Test non-standard out= parameter for clip() + + (see "Avoid Restricting Behavior that is Outside the Scope of the Standard" + in https://data-apis.org/array-api-compat/dev/special-considerations.html) + """ + xp = import_(library, wrapper=True) + x = xp.asarray([10, 20, 30]) + out = xp.zeros_like(x) + xp.clip(x, 15, 25, out=out) + expect = xp.asarray([15, 20, 25]) + assert xp.all(out == expect) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py new file mode 100644 index 00000000..1e564694 --- /dev/null +++ b/tests/test_copies_or_views.py @@ -0,0 +1,76 @@ +""" +A collection of tests to make sure that wrapped namespaces agree with the bare ones +on whether to return a view or a copy of inputs. 
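+Each check applies a function to a fresh array, mutates the function's output
+in place, and then observes whether the input array changed as well.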
+""" +import pytest +from ._helpers import import_, wrapped_libraries + + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped + + +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) +def test_clip_none(xp_name): + xp = import_(xp_name, wrapper=True) + + if xp_name == 'array_api_strict' and xp.__version__ < "2.5": + # https://github.com/data-apis/array-api-strict/pull/180 + pytest.xfail("clip(x) was only fixed in -strict == 2.5") + + x = xp.arange(8) + assert not is_view(xp.clip, x, 42) diff --git a/tests/test_cupy.py b/tests/test_cupy.py new file mode 100644 index 00000000..4745b983 --- /dev/null +++ b/tests/test_cupy.py @@ -0,0 +1,45 @@ +import pytest +from array_api_compat import device, to_device + +xp = pytest.importorskip("array_api_compat.cupy") +from cupy.cuda import Stream + + +@pytest.mark.parametrize( + "make_stream", + [ + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + ], +) +def test_to_device_with_stream(make_stream): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=stream) + assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. 
+        b = to_device(a, dev, stream=s1.ptr)
+        assert device(b) == dev
diff --git a/tests/test_dask.py b/tests/test_dask.py
index be2b1e39..4200e5b7 100644
--- a/tests/test_dask.py
+++ b/tests/test_dask.py
@@ -1,10 +1,14 @@
+import sys
from contextlib import contextmanager
-import array_api_strict
-import dask
import numpy as np
import pytest
-import dask.array as da
+
+try:
+    import dask
+    import dask.array as da
+except ImportError:
+    pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")
from array_api_compat import array_namespace
@@ -164,12 +168,17 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks):
    )
+@pytest.mark.skipif(
+    sys.version_info.major*100 + sys.version_info.minor < 312,
+    reason="dask interop requires Python >= 3.12"
+)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
    """Test meta-namespace other than numpy"""
-    typ = type(array_api_strict.asarray(0))
+    mxp = pytest.importorskip("array_api_strict")
+    typ = type(mxp.asarray(0))
    a = da.random.random(10)
-    b = a.map_blocks(array_api_strict.asarray)
+    b = a.map_blocks(mxp.asarray)
    assert isinstance(b._meta, typ)
    c = getattr(xp, func)(b)
    assert isinstance(c._meta, typ)
diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py
index 6ad45d4c..92e45613 100644
--- a/tests/test_isdtype.py
+++ b/tests/test_isdtype.py
@@ -61,7 +61,7 @@ def isdtype_(dtype_, kind):
        res = dtype_categories[kind](dtype_)
    else:
        res = dtype_ == kind
-    assert type(res) is bool  # noqa: E721
+    assert type(res) is bool
    return res
@pytest.mark.parametrize("library", wrapped_libraries)
diff --git a/tests/test_jax.py b/tests/test_jax.py
index e33cec02..322d0223 100644
--- a/tests/test_jax.py
+++ b/tests/test_jax.py
@@ -1,15 +1,26 @@
-import jax
-import jax.numpy as jnp
from numpy.testing import assert_equal
import pytest
-from array_api_compat import device, to_device
+from array_api_compat import (
+    device,
+    to_device,
+    is_jax_array,
+    is_lazy_array,
+    is_array_api_obj,
+    is_writeable_array,
+)
+
+try:
+    import jax
+    import jax.numpy as jnp
+except ImportError:
+    pytestmark = pytest.skip(allow_module_level=True, reason="jax not found")
HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31"
@pytest.mark.parametrize(
-    "func", 
+    "func",
    [
        lambda x: jnp.zeros(1, device=device(x)),
        lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))),
@@ -22,7 +33,7 @@
            ),
        ),
        lambda x: to_device(jnp.zeros(1), device(x)),
-    ]
+    ],
)
def test_device_jit(func):
    # Test work around to https://github.com/jax-ml/jax/issues/26000
@@ -32,3 +43,15 @@ def test_device_jit(func):
    x = jnp.ones(1)
    assert_equal(func(x), jnp.asarray([0]))
    assert_equal(jax.jit(func)(x), jnp.asarray([0]))
+
+
+def test_inside_jit():
+    # Test if jax arrays are handled correctly inside jax.jit.
+    # Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that
+    # tracers are handled appropriately. For limitations, see is_jax_array() docstring.
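+    # Under jax.jit the function body sees abstract tracers rather than
+    # concrete jax.Array instances, so these predicates must recognize
+    # tracers as well.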
+ # Reference issue: https://github.com/data-apis/array-api-compat/issues/368 + x = jnp.asarray([1, 2, 3]) + assert jax.jit(is_jax_array)(x) + assert jax.jit(is_array_api_obj)(x) + assert not jax.jit(is_writeable_array)(x) + assert jax.jit(is_lazy_array)(x) diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py index a1fdf731..624f8971 100644 --- a/tests/test_no_dependencies.py +++ b/tests/test_no_dependencies.py @@ -17,7 +17,7 @@ class Array: # Dummy array namespace that doesn't depend on any array library def __array_namespace__(self, api_version=None): class Namespace: - pass + __name__: str = "foobar" return Namespace() def _test_dependency(mod): @@ -38,11 +38,11 @@ def _test_dependency(mod): assert not is_mod_array(a) assert mod not in sys.modules - is_array_api_obj = getattr(array_api_compat, "is_array_api_obj") + is_array_api_obj = array_api_compat.is_array_api_obj assert is_array_api_obj(a) assert mod not in sys.modules - array_namespace = getattr(array_api_compat, "array_namespace") + array_namespace = array_api_compat.array_namespace array_namespace(Array()) assert mod not in sys.modules diff --git a/tests/test_numpy.py b/tests/test_numpy.py new file mode 100644 index 00000000..a139d428 --- /dev/null +++ b/tests/test_numpy.py @@ -0,0 +1,19 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import warnings +import pytest + +try: + import numpy as np +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="numpy not found") + +from array_api_compat import is_array_api_obj + +def test_matrix_is_not_array_api_obj(): + assert is_array_api_obj(np.asarray(3)) + assert is_array_api_obj(np.float64(3)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + assert not is_array_api_obj(np.matrix(3)) diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..3d6ebc46 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,180 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import itertools + +import pytest + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") + +from array_api_compat import torch as xp + + +class TestResultType: + def test_empty(self): + with pytest.raises(ValueError): + xp.result_type() + + def test_one_arg(self): + for x in [1, 1.0, 1j, '...', None]: + with pytest.raises((ValueError, AttributeError)): + xp.result_type(x) + + for x in [xp.float32, xp.int64, torch.complex64]: + assert xp.result_type(x) == x + + for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: + assert xp.result_type(x) == x.dtype + + def test_two_args(self): + # Only include here things "unspecified" in the spec + + # scalar, tensor or tensor,tensor + for x, y in [ + (1., 1j), + (1j, xp.arange(3)), + (True, xp.asarray(3.)), + (xp.ones(3) == 1, 1j*xp.ones(3)), + ]: + assert xp.result_type(x, y) == torch.result_type(x, y) + + # dtype, scalar + for x, y in [ + (1j, xp.int64), + (True, xp.float64), + ]: + assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) + + # dtype, dtype + for x, y in [ + (xp.bool, xp.complex64) + ]: + xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) + assert xp.result_type(x, y) == torch.result_type(xt, yt) + + def test_multi_arg(self): + torch.set_default_dtype(torch.float32) + + args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] 
+ assert xp.result_type(*args) == torch.float16 + + args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] + assert xp.result_type(*args) == xp.complex64 + + args = [1, 2, 3j, xp.float64, 4, 5, 6] + assert xp.result_type(*args) == xp.complex128 + + args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] + assert xp.result_type(*args) == xp.complex128 + + i64 = xp.ones(1, dtype=xp.int64) + f16 = xp.ones(1, dtype=xp.float16) + for i in itertools.permutations([i64, f16, 1.0, 1.0]): + assert xp.result_type(*i) == xp.float16, f"{i}" + + with pytest.raises(ValueError): + xp.result_type(1, 2, 3, 4) + + + @pytest.mark.parametrize("default_dt", ['float32', 'float64']) + @pytest.mark.parametrize("dtype_a", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + @pytest.mark.parametrize("dtype_b", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + def test_gh_273(self, default_dt, dtype_a, dtype_b): + # Regression test for https://github.com/data-apis/array-api-compat/issues/273 + + try: + prev_default = torch.get_default_dtype() + default_dtype = getattr(torch, default_dt) + torch.set_default_dtype(default_dtype) + + a = xp.asarray([2, 1], dtype=dtype_a) + b = xp.asarray([1, -1], dtype=dtype_b) + dtype_1 = xp.result_type(a, b, 1.0) + dtype_2 = xp.result_type(b, a, 1.0) + assert dtype_1 == dtype_2 + finally: + torch.set_default_dtype(prev_default) + + +def test_clip_vmap(): + # https://github.com/data-apis/array-api-compat/issues/350 + def apply_clip_compat(a): + return xp.clip(a, min=0, max=30) + + a = xp.asarray([[5.1, 2.0, 64.1, -1.5]]) + + ref = apply_clip_compat(a) + v1 = torch.vmap(apply_clip_compat) + assert xp.all(v1(a) == ref) + + +def test_meshgrid(): + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy', and supports passing no arrays.""" + + x, y = xp.asarray([1, 2]), xp.asarray([4]) + + X, Y = xp.meshgrid(x, y) + + # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different + X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) + + assert X.shape == X_xy.shape + assert xp.all(X == X_xy) + + assert Y.shape == Y_xy.shape + assert xp.all(Y == Y_xy) + + # repeat with an explicit indexing + X, Y = xp.meshgrid(x, y, indexing='ij') + + # output of torch.meshgrid(x, y, indexing='ij') + X_ij, Y_ij = xp.asarray([[1], [2]]), xp.asarray([[4], [4]]) + + assert X.shape == X_ij.shape + assert xp.all(X == X_ij) + + assert Y.shape == Y_ij.shape + assert xp.all(Y == Y_ij) + + assert not xp.meshgrid() + + +def test_argsort_stable(): + """Verify that argsort defaults to a stable sort.""" + # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper + # enforces the stable=True default. 
+ # cf https://github.com/data-apis/array-api-compat/pull/356 and + # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329 + + t = xp.zeros(50) # should be >16 + assert xp.all(xp.argsort(t) == xp.arange(50)) + + +def test_round(): + """Verify the out= argument of xp.round with complex inputs.""" + x = torch.as_tensor([1.23456786]*3) + 3.456789j + o = torch.empty(3, dtype=torch.complex64) + r = xp.round(x, decimals=1, out=o) + assert xp.all(r == o) + assert r is o + + +def test_dynamo_array_namespace(): + """Check that torch.compiling array_namespace does not incur graph breaks.""" + from array_api_compat import array_namespace + + def foo(x): + xp = array_namespace(x) + return xp.multiply(x, x) + + bar = torch.compile(fullgraph=True)(foo) + + x = torch.arange(3) + y = bar(x) + assert xp.all(y == x**2) diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() diff --git a/torch-skips.txt b/torch-skips.txt index cbf7235b..e69de29b 100644 --- a/torch-skips.txt +++ b/torch-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/torch-xfails.txt b/torch-xfails.txt index 2899bdb3..22aeed8f 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -6,7 +6,7 @@ # Indexing does not support negative step array_api_tests/test_array_object.py::test_getitem array_api_tests/test_array_object.py::test_setitem -# Masking doesn't suport 0 dimensions in the mask +# Masking doesn't support 0 dimensions in the mask array_api_tests/test_array_object.py::test_getitem_masking # Overflow error from large inputs @@ -29,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] 
+array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -104,6 +108,21 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +# complex cases +array_api_tests/test_special_cases.py::test_unary[acos((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> \u03c0/2 - 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[log1p(isfinite(real(x_i)) and imag(x_i) is +infinity) -> +infinity + \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[log1p((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[log1p(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + # Float correction is not supported by pytorch # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_std @@ -111,28 +130,43 @@ array_api_tests/test_statistical_functions.py::test_var # These functions do not yet support complex numbers -array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values +# finfo/iinfo.dtype is a string instead of a dtype +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo_dtype + # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] -array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_signatures.py::test_func_signature[repeat] # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] - -# 
2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] -array_api_tests/test_signatures.py::test_array_method_signature[__and__] -array_api_tests/test_signatures.py::test_array_method_signature[__lshift__] -array_api_tests/test_signatures.py::test_array_method_signature[__or__] -array_api_tests/test_signatures.py::test_array_method_signature[__rshift__] -array_api_tests/test_signatures.py::test_array_method_signature[__xor__] +# 2025.12 support + +# broadcast_shapes emits a RuntimeError where the spec says ValueError +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error + + +# 2024.12 support: binary functions reject python scalar arguments +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] + +# https://github.com/pytorch/pytorch/issues/149815 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] + +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor]
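
For context on the scalar xfails above: the 2024.12 standard requires elementwise binary functions to accept a Python scalar in place of one array argument, but several of the underlying torch kernels only take tensors. A minimal illustrative sketch of what the xfailed tests exercise (assumes torch and array-api-compat are installed; not part of the patch):

    from array_api_compat import torch as xp

    x = xp.asarray([1.0, 2.0, 3.0])

    # Many elementwise functions already take a Python scalar, because the
    # underlying torch kernel accepts a plain Number:
    assert xp.all(xp.add(x, 1.0) == xp.asarray([2.0, 3.0, 4.0]))

    # ...but e.g. torch.maximum only accepts tensors, which is why
    # test_binary_with_scalars_real[maximum] is xfailed above:
    try:
        xp.maximum(x, 1.0)
    except TypeError:
        pass  # scalar argument rejected by the torch kernel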