From cac1443942212f872909fc6b6f2b610d8a5a6b84 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 4 Mar 2025 10:22:14 +0100 Subject: [PATCH 001/151] DOC: add a changelog for the 1.11.1 release --- docs/changelog.md | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index 1de11606..bdf5f9e1 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,27 @@ # Changelog -## 1.11.0 (2025-XX-XX) +## 1.11.1 (2025-03-04) + +This is a bugfix release with no new features compared to version 1.11. + +### Major Changes + +- fix `count_nonzero` wrappers: work around the lack of the `keepdims` argument in + several array libraries (torch, dask, cupy); work around numpy returning python + ints in for some input combinations. + +### Minor Changes + +- runnings self-tests does not require all array libraries. Missing libraries are + skipped. + +The following users contributed to this release: + +Evgeni Burovski +Guido Imperiale + + +## 1.11.0 (2025-02-27) ### Major Changes From e14754ba0fe4c4cd51b6f45bb11a3c6609be3b5c Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Wed, 5 Mar 2025 10:04:21 +0000 Subject: [PATCH 002/151] BUG: `clip(out=...)` is broken (#261) reviewed at https://github.com/data-apis/array-api-compat/pull/261 --- array_api_compat/common/_aliases.py | 26 +++++++++++++------------- tests/test_common.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 98b8e425..d7e8ef2d 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -12,7 +12,7 @@ from typing import NamedTuple import inspect -from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace +from ._helpers import array_namespace, _check_device, device, is_cupy_namespace # These functions are modified from the NumPy versions. @@ -368,23 +368,23 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None + dev = device(x) if out is None: - out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), - copy=True, device=device(x)) + out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + out[()] = x + if min is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min): - # Avoid loss of precision due to torch defaulting to float32 - min = wrapped_xp.asarray(min, dtype=xp.float64) - a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape) + a = wrapped_xp.asarray(min, dtype=x.dtype, device=dev) + a = xp.broadcast_to(a, result_shape) ia = (out < a) | xp.isnan(a) - # torch requires an explicit cast here - out[ia] = wrapped_xp.astype(a[ia], out.dtype) + out[ia] = a[ia] + if max is not None: - if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max): - max = wrapped_xp.asarray(max, dtype=xp.float64) - b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape) + b = wrapped_xp.asarray(max, dtype=x.dtype, device=dev) + b = xp.broadcast_to(b, result_shape) ib = (out > b) | xp.isnan(b) - out[ib] = wrapped_xp.astype(b[ib], out.dtype) + out[ib] = b[ib] + # Return a scalar for 0-D return out[()] diff --git a/tests/test_common.py b/tests/test_common.py index 32876e69..f86e0936 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -367,3 +367,18 @@ def test_asarray_copy(library): assert all(b[0] == 1.0) else: assert all(b[0] == 0.0) + + +@pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) +def test_clip_out(library): + """Test non-standard out= parameter for clip() + + (see "Avoid Restricting Behavior that is Outside the Scope of the Standard" + in https://data-apis.org/array-api-compat/dev/special-considerations.html) + """ + xp = import_(library, wrapper=True) + x = xp.asarray([10, 20, 30]) + out = xp.zeros_like(x) + xp.clip(x, 15, 25, out=out) + expect = xp.asarray([15, 20, 25]) + assert xp.all(out == expect) From 3f14b184dbbd47e81d1d47514c7b5a2772969b81 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 15 Mar 2025 12:06:31 +0000 Subject: [PATCH 003/151] add torch xfails for scalars in binary functions --- torch-xfails.txt | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 2899bdb3..6e8f7dc6 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -124,7 +124,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] - # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] @@ -136,3 +135,22 @@ array_api_tests/test_signatures.py::test_array_method_signature[__lshift__] array_api_tests/test_signatures.py::test_array_method_signature[__or__] array_api_tests/test_signatures.py::test_array_method_signature[__rshift__] array_api_tests/test_signatures.py::test_array_method_signature[__xor__] + +# 2024.12 support: binary functions reject python scalar arguments +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] + +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_and] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_or] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_bool[logical_xor] From 05ade6738542ff556da4bbaa6ad4acd29290d989 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Magnus=20Dalen=20Kvalev=C3=A5g?= Date: Mon, 17 Mar 2025 17:20:59 +0100 Subject: [PATCH 004/151] Fix clipping float with python int for min and max --- array_api_compat/common/_aliases.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index d7e8ef2d..35262d3a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -363,10 +363,11 @@ def _isscalar(a): # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). - if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: - min = None - if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: - max = None + if wrapped_xp.isdtype(x.dtype, "integral"): + if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: + max = None dev = device(x) if out is None: From 58d8037f372113dfc4d0b36b3740a8b34ed85c7f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 17 Mar 2025 19:29:07 +0100 Subject: [PATCH 005/151] BUG: torch: fix `result_type` with python scalars 1. Allow inputs to be arrays or dtypes or python scalars 2. Keep the pytorch-specific additions, e.g. `result_type(int, float) -> float`, `result_type(scalar, scalar) -> dtype` which are unspecified in the standard 3. Since pytorch only defines a binary `result_type` function, add a version with multiple inputs. The latter is a bit tricky because we want to - keep allowing "unspecified" behaviors - keep standard-allowed promotions compliant - (preferably) make result_type independent on the argument order The latter is important because of `int,float->float` promotions which break associativity. So what we do, we always promote all scalars after all array/dtype arguments. --- array_api_compat/torch/_aliases.py | 45 ++++++++++---- tests/test_torch.py | 98 ++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 13 deletions(-) create mode 100644 tests/test_torch.py diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index b4786320..4b727f1c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from functools import wraps as _wraps +from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any from ..common import _aliases @@ -124,25 +124,43 @@ def _fix_promotion(x1, x2, only_scalar=True): def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: - if len(arrays_and_dtypes) == 0: - raise TypeError("At least one array or dtype must be provided") - if len(arrays_and_dtypes) == 1: + num = len(arrays_and_dtypes) + + if num == 0: + raise ValueError("At least one array or dtype must be provided") + + elif num == 1: x = arrays_and_dtypes[0] if isinstance(x, torch.dtype): return x return x.dtype - if len(arrays_and_dtypes) > 2: - return result_type(arrays_and_dtypes[0], result_type(*arrays_and_dtypes[1:])) - x, y = arrays_and_dtypes - if isinstance(x, _py_scalars) or isinstance(y, _py_scalars): - return torch.result_type(x, y) + if num == 2: + x, y = arrays_and_dtypes + return _result_type(x, y) + + else: + # sort scalars so that they are treated last + scalars, others = [], [] + for x in arrays_and_dtypes: + if isinstance(x, _py_scalars): + scalars.append(x) + else: + others.append(x) + if not others: + raise ValueError("At least one array or dtype must be provided") + + # combine left-to-right + return _reduce(_result_type, others + scalars) - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y - if (xdt, ydt) in _promotion_table: - return _promotion_table[xdt, ydt] +def _result_type(x, y): + if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): + xdt = x.dtype if not isinstance(x, torch.dtype) else x + ydt = y.dtype if not isinstance(y, torch.dtype) else y + + if (xdt, ydt) in _promotion_table: + return _promotion_table[xdt, ydt] # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -151,6 +169,7 @@ def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, comple y = torch.tensor([], dtype=y) if isinstance(y, torch.dtype) else y return torch.result_type(x, y) + def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 00000000..75b3a136 --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,98 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import itertools + +import pytest +import torch + +from array_api_compat import torch as xp + + +class TestResultType: + def test_empty(self): + with pytest.raises(ValueError): + xp.result_type() + + def test_one_arg(self): + for x in [1, 1.0, 1j, '...', None]: + with pytest.raises((ValueError, AttributeError)): + xp.result_type(x) + + for x in [xp.float32, xp.int64, torch.complex64]: + assert xp.result_type(x) == x + + for x in [xp.asarray(True, dtype=xp.bool), xp.asarray(1, dtype=xp.complex64)]: + assert xp.result_type(x) == x.dtype + + def test_two_args(self): + # Only include here things "unspecified" in the spec + + # scalar, tensor or tensor,tensor + for x, y in [ + (1., 1j), + (1j, xp.arange(3)), + (True, xp.asarray(3.)), + (xp.ones(3) == 1, 1j*xp.ones(3)), + ]: + assert xp.result_type(x, y) == torch.result_type(x, y) + + # dtype, scalar + for x, y in [ + (1j, xp.int64), + (True, xp.float64), + ]: + assert xp.result_type(x, y) == torch.result_type(x, xp.empty([], dtype=y)) + + # dtype, dtype + for x, y in [ + (xp.bool, xp.complex64) + ]: + xt, yt = xp.empty([], dtype=x), xp.empty([], dtype=y) + assert xp.result_type(x, y) == torch.result_type(xt, yt) + + def test_multi_arg(self): + torch.set_default_dtype(torch.float32) + + args = [1., 5, 3, torch.asarray([3], dtype=torch.float16), 5, 6, 1.] + assert xp.result_type(*args) == torch.float16 + + args = [1, 2, 3j, xp.arange(3, dtype=xp.float32), 4, 5, 6] + assert xp.result_type(*args) == xp.complex64 + + args = [1, 2, 3j, xp.float64, 4, 5, 6] + assert xp.result_type(*args) == xp.complex128 + + args = [1, 2, 3j, xp.float64, 4, xp.asarray(3, dtype=xp.int16), 5, 6, False] + assert xp.result_type(*args) == xp.complex128 + + i64 = xp.ones(1, dtype=xp.int64) + f16 = xp.ones(1, dtype=xp.float16) + for i in itertools.permutations([i64, f16, 1.0, 1.0]): + assert xp.result_type(*i) == xp.float16, f"{i}" + + with pytest.raises(ValueError): + xp.result_type(1, 2, 3, 4) + + + @pytest.mark.parametrize("default_dt", ['float32', 'float64']) + @pytest.mark.parametrize("dtype_a", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + @pytest.mark.parametrize("dtype_b", + (xp.int32, xp.int64, xp.float32, xp.float64, xp.complex64, xp.complex128) + ) + def test_gh_273(self, default_dt, dtype_a, dtype_b): + # Regression test for https://github.com/data-apis/array-api-compat/issues/273 + + try: + prev_default = torch.get_default_dtype() + default_dtype = getattr(torch, default_dt) + torch.set_default_dtype(default_dtype) + + a = xp.asarray([2, 1], dtype=dtype_a) + b = xp.asarray([1, -1], dtype=dtype_b) + dtype_1 = xp.result_type(a, b, 1.0) + dtype_2 = xp.result_type(b, a, 1.0) + assert dtype_1 == dtype_2 + finally: + torch.set_default_dtype(prev_default) From 5473d84d5c36b23e091b880279c863c32f41b828 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 19 Mar 2025 16:28:49 +0100 Subject: [PATCH 006/151] TST: skip test_all --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 10a2a95d..d2e9b768 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -16,6 +16,7 @@ import pytest +@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From c9622f965be76e947f2a7d3ecac827a90e67edfb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 13:11:49 +0100 Subject: [PATCH 007/151] DOC: add a changelog for 1.11.2 release --- docs/changelog.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index bdf5f9e1..18928e98 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,22 @@ # Changelog +## 1.11.2 (2025-03-20) + +This is a bugfix release with no new features compared to version 1.11. + +- fix the `result_type` wrapper for pytorch. Previously, `result_type` had multiple + issues with scalar arguments. +- fix several issues with `clip` wrappers. Previously, `clip` was failing to allow + behaviors which are unspecified by the 2024.12 standard but allowed by the array + libraries. + +The following users contributed to this release: + +Evgeni Burovski +Guido Imperiale +Magnus Dalen Kvalevåg + + ## 1.11.1 (2025-03-04) This is a bugfix release with no new features compared to version 1.11. From b8323760865a66ad03114ad80c1aa058df28dc98 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 13:17:22 +0100 Subject: [PATCH 008/151] BLD: upper cap setuptools, do not error on deprecationwarnings --- .github/workflows/publish-package.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 7733059d..6d88066d 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -41,13 +41,14 @@ jobs: - name: Install python-build and twine run: | - python -m pip install --upgrade pip setuptools + python -m pip install --upgrade pip "setuptools<=67" python -m pip install build twine python -m pip list - name: Build a wheel and a sdist run: | - PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + #PYTHONWARNINGS=error,default::DeprecationWarning python -m build . + python -m build . - name: Verify the distribution run: twine check --strict dist/* From 1b0de51538deb7c21d0c268f36764a8589e40012 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 13:33:39 +0100 Subject: [PATCH 009/151] REL: bump the version to 1.11.2 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 60b37e97..96b061e7 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.dev0' +__version__ = '1.11.2' from .common import * # noqa: F401, F403 From b1316cff516d147519a9c30f0e8327e5895598f4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 20 Mar 2025 16:47:38 +0100 Subject: [PATCH 010/151] TST: skip tests of binary funcs w/scalar on older numpies NumPy < 2 fails to promote an empty f32 array with a scalar, returns an empty f64 array --- numpy-1-21-xfails.txt | 3 +++ numpy-1-26-xfails.txt | 3 +++ 2 files changed, 6 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 28c0e13a..7c7a0757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -212,3 +212,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 80790534..57259b6f 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -66,3 +66,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real From 64ab7e26b86d0cd2d4cb544fdd39699a887823e8 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 01:27:29 +0100 Subject: [PATCH 011/151] MAINT: update the version for 1.12.dev0 development --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 96b061e7..60b37e97 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.11.2' +__version__ = '1.12.dev0' from .common import * # noqa: F401, F403 From 0080afed5b110c311cb88314d0370a2a3fcbefef Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 21 Mar 2025 12:02:54 +0100 Subject: [PATCH 012/151] Add a CuPy xfail CuPy 13.x follows NumPy 1.x without "weak scalars". In NumPy `result_type(int32, uint8, 1) != result_type(int32, uint8)` has been fixed in 2.x (or 1.x with set_promotion_state("weak"), so hopefully CuPy 14.x follows the suite, when released. Until then, just xfail the test. --- cupy-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 63e844cd..3d20d745 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -183,7 +183,7 @@ array_api_tests/test_manipulation_functions.py::test_repeat array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -+# 2024.12 support +# 2024.12 support array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] @@ -192,3 +192,5 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] +# cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars From a5a1d8ba722da9b8a2783ccd63c0b60713932793 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Sat, 22 Mar 2025 17:34:57 +0000 Subject: [PATCH 013/151] TYP: Type annotations overhaul, part 1 (#257) * ENH: Type annotations overhaul * Re-add py.typed * code review * lint * asarray * fill_value * result_type * lint * Arrays don't need to support buffer protocol * bool is a subclass of int * reshape: copy kwarg is keyword-only * tensordot formatting * Reinstate explicit bool | complex --- array_api_compat/common/_aliases.py | 248 +++++++++++++----------- array_api_compat/common/_fft.py | 87 +++++---- array_api_compat/common/_helpers.py | 32 +-- array_api_compat/common/_linalg.py | 84 +++++--- array_api_compat/common/_typing.py | 16 +- array_api_compat/cupy/_aliases.py | 36 ++-- array_api_compat/cupy/_typing.py | 63 +++--- array_api_compat/dask/array/_aliases.py | 54 ++---- array_api_compat/dask/array/fft.py | 13 +- array_api_compat/dask/array/linalg.py | 25 +-- array_api_compat/numpy/_aliases.py | 41 ++-- array_api_compat/numpy/_typing.py | 63 +++--- array_api_compat/py.typed | 0 array_api_compat/torch/_aliases.py | 168 ++++++++-------- array_api_compat/torch/_typing.py | 4 + array_api_compat/torch/fft.py | 35 ++-- array_api_compat/torch/linalg.py | 28 ++- setup.py | 5 +- tests/test_all.py | 17 +- 19 files changed, 511 insertions(+), 508 deletions(-) create mode 100644 array_api_compat/py.typed create mode 100644 array_api_compat/torch/_typing.py diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 35262d3a..0d123b99 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -4,15 +4,11 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Sequence, Tuple, Union - from ._typing import ndarray, Device, Dtype - -from typing import NamedTuple import inspect +from typing import NamedTuple, Optional, Sequence, Tuple, Union from ._helpers import array_namespace, _check_device, device, is_cupy_namespace +from ._typing import Array, Device, DType, Namespace # These functions are modified from the NumPy versions. @@ -24,29 +20,34 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - **kwargs -) -> ndarray: + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + **kwargs, +) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) @@ -55,37 +56,37 @@ def eye( n_cols: Optional[int] = None, /, *, - xp, + xp: Namespace, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], - fill_value: Union[int, float], - xp, + fill_value: complex, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( - x: ndarray, + x: Array, /, - fill_value: Union[int, float], + fill_value: complex, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) @@ -95,48 +96,58 @@ def linspace( /, num: int, *, - xp, - dtype: Optional[Dtype] = None, + xp: Namespace, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], - xp, + xp: Namespace, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( - x: ndarray, /, xp, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, + x: Array, + /, + xp: Namespace, + *, + dtype: Optional[DType] = None, + device: Optional[Device] = None, **kwargs, -) -> ndarray: +) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) @@ -150,23 +161,23 @@ def zeros_like( # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): - values: ndarray - indices: ndarray - inverse_indices: ndarray - counts: ndarray + values: Array + indices: Array + inverse_indices: Array + counts: Array class UniqueCountsResult(NamedTuple): - values: ndarray - counts: ndarray + values: Array + counts: Array class UniqueInverseResult(NamedTuple): - values: ndarray - inverse_indices: ndarray + values: Array + inverse_indices: Array -def _unique_kwargs(xp): +def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # Older versions of NumPy and CuPy do not have equal_nan. Rather than # trying to parse version numbers, just check if equal_nan is in the # signature. @@ -175,7 +186,7 @@ def _unique_kwargs(xp): return {'equal_nan': False} return {} -def unique_all(x: ndarray, /, xp) -> UniqueAllResult: +def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( x, @@ -195,7 +206,7 @@ def unique_all(x: ndarray, /, xp) -> UniqueAllResult: ) -def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: +def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( x, @@ -208,7 +219,7 @@ def unique_counts(x: ndarray, /, xp) -> UniqueCountsResult: return UniqueCountsResult(*res) -def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: +def unique_inverse(x: Array, /, xp: Namespace) -> UniqueInverseResult: kwargs = _unique_kwargs(xp) values, inverse_indices = xp.unique( x, @@ -223,7 +234,7 @@ def unique_inverse(x: ndarray, /, xp) -> UniqueInverseResult: return UniqueInverseResult(values, inverse_indices) -def unique_values(x: ndarray, /, xp) -> ndarray: +def unique_values(x: Array, /, xp: Namespace) -> Array: kwargs = _unique_kwargs(xp) return xp.unique( x, @@ -236,42 +247,42 @@ def unique_values(x: ndarray, /, xp) -> ndarray: # These functions have different keyword argument names def std( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) def var( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + correction: Union[int, float] = 0.0, # correction instead of ddof keepdims: bool = False, **kwargs, -) -> ndarray: +) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument def cumulative_sum( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. @@ -294,15 +305,15 @@ def cumulative_sum( def cumulative_prod( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, axis: Optional[int] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, include_initial: bool = False, - **kwargs -) -> ndarray: + **kwargs, +) -> Array: wrapped_xp = array_namespace(x) if axis is None: @@ -325,17 +336,18 @@ def cumulative_prod( # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( - x: ndarray, + x: Array, /, - min: Optional[Union[int, float, ndarray]] = None, - max: Optional[Union[int, float, ndarray]] = None, + min: Optional[Union[int, float, Array]] = None, + max: Optional[Union[int, float, Array]] = None, *, - xp, + xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[ndarray] = None, -) -> ndarray: + out: Optional[Array] = None, +) -> Array: def _isscalar(a): return isinstance(a, (int, float, type(None))) + min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -390,15 +402,19 @@ def _isscalar(a): return out[()] # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray: +def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) # np.reshape calls the keyword argument 'newshape' instead of 'shape' -def reshape(x: ndarray, - /, - shape: Tuple[int, ...], - xp, copy: Optional[bool] = None, - **kwargs) -> ndarray: +def reshape( + x: Array, + /, + shape: Tuple[int, ...], + xp: Namespace, + *, + copy: Optional[bool] = None, + **kwargs, +) -> Array: if copy is True: x = x.copy() elif copy is False: @@ -410,9 +426,15 @@ def reshape(x: ndarray, # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -435,9 +457,15 @@ def argsort( return res def sort( - x: ndarray, /, xp, *, axis: int = -1, descending: bool = False, stable: bool = True, + x: Array, + /, + xp: Namespace, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, **kwargs, -) -> ndarray: +) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. @@ -449,50 +477,51 @@ def sort( return res # nonzero should error for zero-dimensional arrays -def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) # ceil, floor, and trunc return integers for integer inputs -def ceil(x: ndarray, /, xp, **kwargs) -> ndarray: +def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: ndarray, /, xp, **kwargs) -> ndarray: +def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: ndarray, /, xp, **kwargs) -> ndarray: +def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) # linear algebra functions -def matmul(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.matmul(x1, x2, **kwargs) # Unlike transpose, matrix_transpose only transposes the last two axes. -def matrix_transpose(x: ndarray, /, xp) -> ndarray: +def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) -def tensordot(x1: ndarray, - x2: ndarray, - /, - xp, - *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, -) -> ndarray: +def tensordot( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) -def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: +def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") @@ -511,8 +540,11 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: # isdtype is a new function in the 2022.12 array API specification. def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: Union[DType, str, Tuple[Union[DType, str], ...]], + xp: Namespace, + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -551,14 +583,14 @@ def isdtype( return dtype == kind # unstack is a new function in the 2023.12 array API standard -def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: ndarray, /, xp, **kwargs) -> ndarray: +def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: if isdtype(x.dtype, 'complex floating', xp=xp): out = (x/xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index e5caebef..bd2a4e1a 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,149 +1,148 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, Optional, Literal +from collections.abc import Sequence +from typing import Union, Optional, Literal -if TYPE_CHECKING: - from ._typing import Device, ndarray, DType - from collections.abc import Sequence +from ._typing import Device, Array, DType, Namespace # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. def fft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def fftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def ifftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) return res def rfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def rfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: return res.astype(xp.complex64) return res def irfftn( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: return res.astype(xp.float32) return res def hfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.float32) return res def ihfft( - x: ndarray, + x: Array, /, - xp, + xp: Namespace, *, n: Optional[int] = None, axis: int = -1, norm: Literal["backward", "ortho", "forward"] = "backward", -) -> ndarray: +) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: return res.astype(xp.complex64) @@ -152,12 +151,12 @@ def ihfft( def fftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.fftfreq(n, d=d) @@ -168,12 +167,12 @@ def fftfreq( def rfftfreq( n: int, /, - xp, + xp: Namespace, *, d: float = 1.0, dtype: Optional[DType] = None, - device: Optional[Device] = None -) -> ndarray: + device: Optional[Device] = None, +) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") res = xp.fft.rfftfreq(n, d=d) @@ -181,10 +180,14 @@ def rfftfreq( return res.astype(dtype) return res -def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def fftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.fftshift(x, axes=axes) -def ifftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray: +def ifftshift( + x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None +) -> Array: return xp.fft.ifftshift(x, axes=axes) __all__ = [ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 791edb81..6d95069d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,16 +7,14 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union, Any - from ._typing import Array, Device, Namespace - import sys import math import inspect import warnings +from typing import Optional, Union, Any + +from ._typing import Array, Device, Namespace + def _is_jax_zero_gradient_array(x: object) -> bool: """Return True if `x` is a zero-gradient array. @@ -268,7 +266,7 @@ def _compat_module_name() -> str: return __name__.removesuffix('.common._helpers') -def is_numpy_namespace(xp) -> bool: +def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -289,7 +287,7 @@ def is_numpy_namespace(xp) -> bool: return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} -def is_cupy_namespace(xp) -> bool: +def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -310,7 +308,7 @@ def is_cupy_namespace(xp) -> bool: return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} -def is_torch_namespace(xp) -> bool: +def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -331,7 +329,7 @@ def is_torch_namespace(xp) -> bool: return xp.__name__ in {'torch', _compat_module_name() + '.torch'} -def is_ndonnx_namespace(xp) -> bool: +def is_ndonnx_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an NDONNX namespace. @@ -350,7 +348,7 @@ def is_ndonnx_namespace(xp) -> bool: return xp.__name__ == 'ndonnx' -def is_dask_namespace(xp) -> bool: +def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -371,7 +369,7 @@ def is_dask_namespace(xp) -> bool: return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} -def is_jax_namespace(xp) -> bool: +def is_jax_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a JAX namespace. @@ -393,7 +391,7 @@ def is_jax_namespace(xp) -> bool: return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} -def is_pydata_sparse_namespace(xp) -> bool: +def is_pydata_sparse_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a pydata/sparse namespace. @@ -412,7 +410,7 @@ def is_pydata_sparse_namespace(xp) -> bool: return xp.__name__ == 'sparse' -def is_array_api_strict_namespace(xp) -> bool: +def is_array_api_strict_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is an array-api-strict namespace. @@ -439,7 +437,11 @@ def _check_api_version(api_version: str) -> None: raise ValueError("Only the 2024.12 version of the array API specification is currently supported") -def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace: +def array_namespace( + *xs: Union[Array, bool, int, float, complex, None], + api_version: Optional[str] = None, + use_compat: Optional[bool] = None, +) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index bfa1f1b9..c77ee3b8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,11 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, NamedTuple -if TYPE_CHECKING: - from typing import Literal, Optional, Tuple, Union - from ._typing import ndarray - import math +from typing import Literal, NamedTuple, Optional, Tuple, Union import numpy as np if np.__version__[0] == "2": @@ -15,50 +11,53 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp +from ._typing import Array, Namespace # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1, **kwargs) -> ndarray: +def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: ndarray, x2: ndarray, /, xp, **kwargs) -> ndarray: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): - eigenvalues: ndarray - eigenvectors: ndarray + eigenvalues: Array + eigenvectors: Array class QRResult(NamedTuple): - Q: ndarray - R: ndarray + Q: Array + R: Array class SlogdetResult(NamedTuple): - sign: ndarray - logabsdet: ndarray + sign: Array + logabsdet: Array class SVDResult(NamedTuple): - U: ndarray - S: ndarray - Vh: ndarray + U: Array + S: Array + Vh: Array # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: ndarray, /, xp, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: ndarray, /, xp, *, mode: Literal['reduced', 'complete'] = 'reduced', +def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', **kwargs) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: ndarray, /, xp, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) -def svd(x: ndarray, /, xp, *, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd( + x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs +) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: +def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) @@ -69,12 +68,12 @@ def cholesky(x: ndarray, /, xp, *, upper: bool = False, **kwargs) -> ndarray: # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: ndarray, +def matrix_rank(x: Array, /, - xp, + xp: Namespace, *, - rtol: Optional[Union[float, ndarray]] = None, - **kwargs) -> ndarray: + rtol: Optional[Union[float, Array]] = None, + **kwargs) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: @@ -88,7 +87,9 @@ def matrix_rank(x: ndarray, tol = S.max(axis=-1, keepdims=True)*xp.asarray(rtol)[..., xp.newaxis] return xp.count_nonzero(S > tol, axis=-1) -def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **kwargs) -> ndarray: +def pinv( + x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs +) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). if rtol is None: @@ -97,15 +98,30 @@ def pinv(x: ndarray, /, xp, *, rtol: Optional[Union[float, ndarray]] = None, **k # These functions are new in the array API spec -def matrix_norm(x: ndarray, /, xp, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> ndarray: +def matrix_norm( + x: Array, + /, + xp: Namespace, + *, + keepdims: bool = False, + ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', +) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: ndarray, /, xp) -> Union[ndarray, Tuple[ndarray, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: return xp.linalg.svd(x, compute_uv=False) -def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> ndarray: +def vector_norm( + x: Array, + /, + xp: Namespace, + *, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims: bool = False, + ord: Optional[Union[int, float]] = 2, +) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make # it so the input is 1-D (for axis=None), or reshape so that norm is done @@ -143,11 +159,15 @@ def vector_norm(x: ndarray, /, xp, *, axis: Optional[Union[int, Tuple[int, ...]] # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) -def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray: - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) +def trace( + x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs +) -> Array: + return xp.asarray( + xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) + ) __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult', 'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet', diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d8acdef7..4c3b356b 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,26 +1,24 @@ from __future__ import annotations +from types import ModuleType as Namespace +from typing import Any, TypeVar, Protocol __all__ = [ + "Array", + "DType", + "Device", + "Namespace", "NestedSequence", "SupportsBufferProtocol", ] -from types import ModuleType -from typing import ( - Any, - TypeVar, - Protocol, -) - _T_co = TypeVar("_T_co", covariant=True) class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -SupportsBufferProtocol = Any +SupportsBufferProtocol = Any Array = Any Device = Any DType = Any -Namespace = ModuleType diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 30d9fe48..ebc7ccd9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,16 +1,14 @@ from __future__ import annotations +from typing import Optional + import cupy as cp from ..common import _aliases, _helpers +from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp - from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType bool = cp.bool_ @@ -66,23 +64,19 @@ _copy_default = object() + # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[bool] = _copy_default, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -112,13 +106,13 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device) @@ -127,10 +121,10 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( - x: ndarray, + x: Array, axis=None, keepdims=False -) -> ndarray: +) -> Array: result = cp.count_nonzero(x, axis) if keepdims: if axis is None: diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index f3d9aab6..66af5d19 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["cp"] -import sys -from typing import ( - Union, - TYPE_CHECKING, -) - -from cupy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +from typing import TYPE_CHECKING +import cupy as cp +from cupy import ndarray as Array from cupy.cuda.device import Device -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = cp.dtype[ + cp.intp + | cp.int8 + | cp.int16 + | cp.int32 + | cp.int64 + | cp.uint8 + | cp.uint16 + | cp.uint32 + | cp.uint64 + | cp.float32 + | cp.float64 + | cp.complex64 + | cp.complex128 + | cp.bool_ + ] else: - Dtype = dtype + DType = cp.dtype diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 80d66281..e737cebd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,16 +1,10 @@ from __future__ import annotations -from typing import Callable - -from ...common import _aliases, array_namespace - -from ..._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import Callable, Optional, Union import numpy as np from numpy import ( - # Dtypes + # dtypes iinfo, finfo, bool_ as bool, @@ -29,22 +23,19 @@ can_cast, result_type, ) - -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Optional, Union - - from ...common._typing import ( - Device, - Dtype, - Array, - NestedSequence, - SupportsBufferProtocol, - ) - import dask.array as da +from ...common import _aliases, array_namespace +from ...common._typing import ( + Array, + Device, + DType, + NestedSequence, + SupportsBufferProtocol, +) +from ..._internal import get_xp +from ._info import __array_namespace_info__ + isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -52,7 +43,7 @@ # da.astype doesn't respect copy=True def astype( x: Array, - dtype: Dtype, + dtype: DType, /, *, copy: bool = True, @@ -84,7 +75,7 @@ def arange( stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, **kwargs, ) -> Array: @@ -144,17 +135,12 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: Union[ - Array, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, @@ -360,4 +346,4 @@ def count_nonzero( 'complex64', 'complex128', 'iinfo', 'finfo', 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["Callable", "array_namespace", "get_xp", "da", "np"] +_all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index aebd86f7..3f40dffe 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -4,9 +4,10 @@ # from dask.array.fft import __all__ as linalg_all _n = {} exec('from dask.array.fft import *', _n) -del _n['__builtins__'] +for k in ("__builtins__", "Sequence", "annotations", "warnings"): + _n.pop(k, None) fft_all = list(_n) -del _n +del _n, k from ...common import _fft from ..._internal import get_xp @@ -16,9 +17,5 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = [elem for elem in fft_all if elem != "annotations"] + ["fftfreq", "rfftfreq"] - -del get_xp -del da -del fft_all -del _fft +__all__ = fft_all + ["fftfreq", "rfftfreq"] +_all_ignore = ["da", "fft_all", "get_xp", "warnings"] diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 49c26d8b..bd53f0df 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -1,33 +1,28 @@ from __future__ import annotations -from ...common import _linalg -from ..._internal import get_xp +from typing import Literal +import dask.array as da # Exports from dask.array.linalg import * # noqa: F403 from dask.array import outer - # These functions are in both the main and linalg namespaces from dask.array import matmul, tensordot -from ._aliases import matrix_transpose, vecdot -import dask.array as da - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from ...common._typing import Array - from typing import Literal +from ..._internal import get_xp +from ...common import _linalg +from ...common._typing import Array +from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -del _n['__builtins__'] -if 'annotations' in _n: - del _n['annotations'] +for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): + _n.pop(k, None) linalg_all = list(_n) -del _n +del _n, k EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -70,4 +65,4 @@ def svdvals(x: Array) -> Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all'] +_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a47f7121..6536d9a8 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,17 +1,15 @@ from __future__ import annotations -from ..common import _aliases +from typing import Optional, Union from .._internal import get_xp - +from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Optional, Union - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol +from ._typing import Array, Device, DType import numpy as np + bool = np.bool_ # Basic renames @@ -64,6 +62,7 @@ tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) + def _supports_buffer_protocol(obj): try: memoryview(obj) @@ -71,26 +70,22 @@ def _supports_buffer_protocol(obj): return False return True + # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: Union[ - ndarray, - bool, - int, - float, - NestedSequence[bool | int | float], - SupportsBufferProtocol, - ], + obj: ( + Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + ), /, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, copy: "Optional[Union[bool, np._CopyMode]]" = None, **kwargs, -) -> ndarray: +) -> Array: """ Array API compatibility wrapper for asarray(). @@ -117,23 +112,19 @@ def asarray( def astype( - x: ndarray, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> ndarray: +) -> Array: return x.astype(dtype=dtype, copy=copy) # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero( - x : ndarray, - axis=None, - keepdims=False -) -> ndarray: +def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: result = np.count_nonzero(x, axis=axis, keepdims=keepdims) if axis is None and not keepdims: return np.asarray(result) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index c5ebb5ab..6a18a3b2 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,46 +1,31 @@ from __future__ import annotations -__all__ = [ - "ndarray", - "Device", - "Dtype", -] +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] -import sys -from typing import ( - Literal, - Union, - TYPE_CHECKING, -) +from typing import Literal, TYPE_CHECKING -from numpy import ( - ndarray, - dtype, - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, -) +import numpy as np +from numpy import ndarray as Array Device = Literal["cpu"] -if TYPE_CHECKING or sys.version_info >= (3, 9): - Dtype = dtype[Union[ - int8, - int16, - int32, - int64, - uint8, - uint16, - uint32, - uint64, - float32, - float64, - ]] +if TYPE_CHECKING: + # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + DType = np.dtype[ + np.intp + | np.int8 + | np.int16 + | np.int32 + | np.int64 + | np.uint8 + | np.uint16 + | np.uint32 + | np.uint64 + | np.float32 + | np.float64 + | np.complex64 + | np.complex128 + | np.bool + ] else: - Dtype = dtype + DType = np.dtype diff --git a/array_api_compat/py.typed b/array_api_compat/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b727f1c..87d32d85 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,21 +2,14 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any - -from ..common import _aliases -from .._internal import get_xp - -from ._info import __array_namespace_info__ +from typing import List, Optional, Sequence, Tuple, Union import torch -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import List, Optional, Sequence, Tuple, Union - from ..common._typing import Device - from torch import dtype as Dtype - - array = torch.Tensor +from .._internal import get_xp +from ..common import _aliases +from ._info import __array_namespace_info__ +from ._typing import Array, Device, DType _int_dtypes = { torch.uint8, @@ -123,7 +116,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Union[array, Dtype, bool, int, float, complex]) -> Dtype: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -170,7 +163,7 @@ def _result_type(x, y): return torch.result_type(x, y) -def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: +def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -216,13 +209,13 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool: # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> array: +def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -235,7 +228,7 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array: +def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -280,13 +273,13 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out -def prod(x: array, +def prod(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -316,13 +309,13 @@ def prod(x: array, return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def sum(x: array, +def sum(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim @@ -347,12 +340,12 @@ def sum(x: array, return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) -def any(x: array, +def any(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -372,12 +365,12 @@ def any(x: array, # torch.any doesn't return bool for uint8 return torch.any(x, axis, keepdims=keepdims).to(torch.bool) -def all(x: array, +def all(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: x = torch.asarray(x) ndim = x.ndim if axis == (): @@ -397,12 +390,12 @@ def all(x: array, # torch.all doesn't return bool for uint8 return torch.all(x, axis, keepdims=keepdims).to(torch.bool) -def mean(x: array, +def mean(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -414,13 +407,13 @@ def mean(x: array, return res return torch.mean(x, axis, keepdims=keepdims, **kwargs) -def std(x: array, +def std(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -445,13 +438,13 @@ def std(x: array, return res return torch.std(x, axis, correction=_correction, keepdims=keepdims, **kwargs) -def var(x: array, +def var(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, correction: Union[int, float] = 0.0, keepdims: bool = False, - **kwargs) -> array: + **kwargs) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -474,11 +467,11 @@ def var(x: array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[array, ...], List[array]], +def concat(arrays: Union[Tuple[Array, ...], List[Array]], /, *, axis: Optional[int] = 0, - **kwargs) -> array: + **kwargs) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -487,7 +480,7 @@ def concat(arrays: Union[Tuple[array, ...], List[array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: +def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -501,27 +494,27 @@ def squeeze(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: array, /, shape: Tuple[int, ...], **kwargs) -> array: +def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: array, /, axes: Tuple[int, ...]) -> array: +def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> array: +def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: +def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -529,25 +522,25 @@ def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]: # torch uses `dim` instead of `axis` def diff( - x: array, + x: Array, /, *, axis: int = -1, n: int = 1, - prepend: Optional[array] = None, - append: Optional[array] = None, -) -> array: + prepend: Optional[Array] = None, + append: Optional[Array] = None, +) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) # torch uses `dim` instead of `axis`, does not have keepdims def count_nonzero( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, -) -> array: +) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: if axis is not None: @@ -557,17 +550,17 @@ def count_nonzero( return result - -def where(condition: array, x1: array, x2: array, /) -> array: +def where(condition: Array, x1: Array, x2: Array, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) # torch.reshape doesn't have the copy keyword -def reshape(x: array, +def reshape(x: Array, /, shape: Tuple[int, ...], + *, copy: Optional[bool] = None, - **kwargs) -> array: + **kwargs) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -581,9 +574,9 @@ def arange(start: Union[int, float], stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -602,9 +595,9 @@ def eye(n_rows: int, /, *, k: int = 0, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -618,10 +611,10 @@ def linspace(start: Union[int, float], /, num: int, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, endpoint: bool = True, - **kwargs) -> array: + **kwargs) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) @@ -629,11 +622,11 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: Union[bool, int, float, complex], + fill_value: complex, *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: if isinstance(shape, int): shape = (shape,) @@ -642,52 +635,52 @@ def full(shape: Union[int, Tuple[int, ...]], # ones, zeros, and empty do not accept shape as a keyword argument def ones(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) def zeros(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) def empty(shape: Union[int, Tuple[int, ...]], *, - dtype: Optional[Dtype] = None, + dtype: Optional[DType] = None, device: Optional[Device] = None, - **kwargs) -> array: + **kwargs) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k -def tril(x: array, /, *, k: int = 0) -> array: +def tril(x: Array, /, *, k: int = 0) -> Array: return torch.tril(x, k) -def triu(x: array, /, *, k: int = 0) -> array: +def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: Array, /, *, axis: int = 0) -> Array: return torch.unsqueeze(x, axis) def astype( - x: array, - dtype: Dtype, + x: Array, + dtype: DType, /, *, copy: bool = True, device: Optional[Device] = None, -) -> array: +) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: array) -> List[array]: +def broadcast_arrays(*arrays: Array) -> List[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -697,7 +690,7 @@ def broadcast_arrays(*arrays: array) -> List[array]: UniqueInverseResult) # https://github.com/pytorch/pytorch/issues/70920 -def unique_all(x: array) -> UniqueAllResult: +def unique_all(x: Array) -> UniqueAllResult: # torch.unique doesn't support returning indices. # https://github.com/pytorch/pytorch/issues/36748. The workaround # suggested in that issue doesn't actually function correctly (it relies @@ -710,7 +703,7 @@ def unique_all(x: array) -> UniqueAllResult: # counts[torch.isnan(values)] = 1 # return UniqueAllResult(values, indices, inverse_indices, counts) -def unique_counts(x: array) -> UniqueCountsResult: +def unique_counts(x: Array) -> UniqueCountsResult: values, counts = torch.unique(x, return_counts=True) # torch.unique incorrectly gives a 0 count for nan values. @@ -718,14 +711,14 @@ def unique_counts(x: array) -> UniqueCountsResult: counts[torch.isnan(values)] = 1 return UniqueCountsResult(values, counts) -def unique_inverse(x: array) -> UniqueInverseResult: +def unique_inverse(x: Array) -> UniqueInverseResult: values, inverse = torch.unique(x, return_inverse=True) return UniqueInverseResult(values, inverse) -def unique_values(x: array) -> array: +def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: array, x2: array, /, **kwargs) -> array: +def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -733,12 +726,19 @@ def matmul(x1: array, x2: array, /, **kwargs) -> array: matrix_transpose = get_xp(torch)(_aliases.matrix_transpose) _vecdot = get_xp(torch)(_aliases.vecdot) -def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return _vecdot(x1, x2, axis=axis) # torch.tensordot uses dims instead of axes -def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, **kwargs) -> array: +def tensordot( + x1: Array, + x2: Array, + /, + *, + axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, + **kwargs, +) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -746,7 +746,7 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], def isdtype( - dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], *, _tuple=True, # Disallow nested tuples ) -> bool: """ @@ -781,7 +781,7 @@ def isdtype( else: return dtype == kind -def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array: +def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -789,11 +789,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - return torch.index_select(x, axis, indices, **kwargs) -def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array: +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) -def sign(x: array, /) -> array: +def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 if x.dtype.is_complex: diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py new file mode 100644 index 00000000..29ad3fa7 --- /dev/null +++ b/array_api_compat/torch/_typing.py @@ -0,0 +1,4 @@ +__all__ = ["Array", "DType", "Device"] + +from torch import dtype as DType, Tensor as Array +from ..common._typing import Device diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 3c9117ee..50e6a0d0 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,76 +1,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from typing import Union, Sequence, Literal +from typing import Union, Sequence, Literal -from torch.fft import * # noqa: F403 +import torch import torch.fft +from torch.fft import * # noqa: F403 + +from ._typing import Array # Several torch fft functions do not map axes to dim def fftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) def ifftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) def rfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) def irfftn( - x: array, + x: Array, /, *, s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs, -) -> array: +) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) def fftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) def ifftshift( - x: array, + x: Array, /, *, axes: Union[int, Sequence[int]] = None, **kwargs, -) -> array: +) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e26198b9..7b59a670 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,14 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - import torch - array = torch.Tensor - from torch import dtype as Dtype - from typing import Optional, Union, Tuple, Literal - inf = float('inf') - -from ._aliases import _fix_promotion, sum +import torch +from typing import Optional, Union, Tuple from torch.linalg import * # noqa: F403 @@ -19,15 +12,17 @@ # outer is implemented in torch but aren't in the linalg namespace from torch import outer +from ._aliases import _fix_promotion, sum # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot +from ._typing import Array, DType # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 # torch.cross also does not support broadcasting when it would add new # dimensions https://github.com/pytorch/pytorch/issues/39656 -def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: +def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) if not (-min(x1.ndim, x2.ndim) <= axis < max(x1.ndim, x2.ndim)): raise ValueError(f"axis {axis} out of bounds for cross product of arrays with shapes {x1.shape} and {x2.shape}") @@ -36,7 +31,7 @@ def cross(x1: array, x2: array, /, *, axis: int = -1) -> array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -58,7 +53,7 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: array, x2: array, /, **kwargs) -> array: +def solve(x1: Array, x2: Array, /, **kwargs) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -79,19 +74,20 @@ def solve(x1: array, x2: array, /, **kwargs) -> array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: array, /, *, offset: int = 0, dtype: Optional[Dtype] = None) -> array: +def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) def vector_norm( - x: array, + x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - ord: Union[int, float, Literal[inf, -inf]] = 2, + # float stands for inf | -inf, which are not valid for Literal + ord: Union[int, float, float] = 2, **kwargs, -) -> array: +) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): out = kwargs.get('out') diff --git a/setup.py b/setup.py index 3d2b68a2..2368ccc4 100644 --- a/setup.py +++ b/setup.py @@ -33,5 +33,8 @@ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - ] + ], + package_data={ + "array_api_compat": ["py.typed"], + }, ) diff --git a/tests/test_all.py b/tests/test_all.py index d2e9b768..598fab62 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -15,6 +15,16 @@ from ._helpers import import_, wrapped_libraries import pytest +import typing + +TYPING_NAMES = frozenset(( + "Array", + "Device", + "DType", + "Namespace", + "NestedSequence", + "SupportsBufferProtocol", +)) @pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) @@ -38,8 +48,11 @@ def test_all(library): dir_names = [n for n in dir(module) if not n.startswith('_')] if '__array_namespace_info__' in dir(module): dir_names.append('__array_namespace_info__') - ignore_all_names = getattr(module, '_all_ignore', []) - ignore_all_names += ['annotations', 'TYPE_CHECKING'] + ignore_all_names = set(getattr(module, '_all_ignore', ())) + ignore_all_names |= set(dir(typing)) + ignore_all_names |= {"annotations"} + if not module.__name__.endswith("._typing"): + ignore_all_names |= TYPING_NAMES dir_names = set(dir_names) - set(ignore_all_names) all_names = module.__all__ From 26845bd904ee66bb830463f46bb39f1cc5392275 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:12:31 +0100 Subject: [PATCH 014/151] Revert "TST: skip test_all" This reverts commit 5473d84d5c36b23e091b880279c863c32f41b828. --- tests/test_all.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index 598fab62..eeb67e4b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -26,7 +26,6 @@ "SupportsBufferProtocol", )) -@pytest.mark.skip(reason="TODO: starts failing after adding test_torch.py in gh-277") @pytest.mark.parametrize("library", ["common"] + wrapped_libraries) def test_all(library): if library == "common": From 07a3cd41e1c5804b7c11d358400431e8a53a984a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 11:40:02 +0100 Subject: [PATCH 015/151] MAINT: run self-tests even if a library is missing --- tests/test_array_namespace.py | 6 ++++-- tests/test_dask.py | 8 ++++++-- tests/test_jax.py | 8 ++++++-- tests/test_torch.py | 6 +++++- tests/test_vendoring.py | 2 ++ 5 files changed, 23 insertions(+), 7 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 605c69a1..cdb80007 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -2,10 +2,8 @@ import sys import warnings -import jax import numpy as np import pytest -import torch import array_api_compat from array_api_compat import array_namespace @@ -76,6 +74,7 @@ def test_array_namespace(library, api_version, use_compat): subprocess.run([sys.executable, "-c", code], check=True) def test_jax_zero_gradient(): + jax = import_("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) @@ -89,11 +88,13 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) def test_array_namespace_errors_torch(): + torch = import_("torch") y = torch.asarray([1, 2]) x = np.asarray([1, 2]) pytest.raises(TypeError, lambda: array_namespace(x, y)) def test_api_version_torch(): + torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) assert array_namespace(x, api_version="2023.12") == torch_ @@ -118,6 +119,7 @@ def test_get_namespace(): assert array_api_compat.get_namespace is array_namespace def test_python_scalars(): + torch = import_("torch") a = torch.asarray([1, 2]) xp = import_("torch", wrapper=True) diff --git a/tests/test_dask.py b/tests/test_dask.py index be2b1e39..69c738f6 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,10 +1,14 @@ from contextlib import contextmanager import array_api_strict -import dask import numpy as np import pytest -import dask.array as da + +try: + import dask + import dask.array as da +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="dask not found") from array_api_compat import array_namespace diff --git a/tests/test_jax.py b/tests/test_jax.py index e33cec02..285958d4 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,10 +1,14 @@ -import jax -import jax.numpy as jnp from numpy.testing import assert_equal import pytest from array_api_compat import device, to_device +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="jax not found") + HAS_JAX_0_4_31 = jax.__version__ >= "0.4.31" diff --git a/tests/test_torch.py b/tests/test_torch.py index 75b3a136..e8340f31 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -3,7 +3,11 @@ import itertools import pytest -import torch + +try: + import torch +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="pytorch not found") from array_api_compat import torch as xp diff --git a/tests/test_vendoring.py b/tests/test_vendoring.py index 70083b49..8b561551 100644 --- a/tests/test_vendoring.py +++ b/tests/test_vendoring.py @@ -16,11 +16,13 @@ def test_vendoring_cupy(): def test_vendoring_torch(): + pytest.importorskip("torch") from vendor_test import uses_torch uses_torch._test_torch() def test_vendoring_dask(): + pytest.importorskip("dask") from vendor_test import uses_dask uses_dask._test_dask() From 89466a6b43672b9a4a2dbdaea2896c24e4dcdd76 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Mar 2025 14:01:44 +0100 Subject: [PATCH 016/151] MAINT: common._aliases.__all__ --- array_api_compat/common/_aliases.py | 18 +++++++++++++----- tests/test_all.py | 2 +- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d123b99..0d1ecfbc 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -7,8 +7,14 @@ import inspect from typing import NamedTuple, Optional, Sequence, Tuple, Union -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace from ._typing import Array, Device, DType, Namespace +from ._helpers import ( + array_namespace, + _check_device, + device as _get_device, + is_cupy_namespace as _is_cupy_namespace +) + # These functions are modified from the NumPy versions. @@ -298,7 +304,7 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -328,7 +334,7 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], axis=axis, ) return res @@ -381,7 +387,7 @@ def _isscalar(a): if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max: max = None - dev = device(x) + dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) out[()] = x @@ -599,7 +605,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -611,3 +617,5 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', 'unstack', 'sign'] + +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/tests/test_all.py b/tests/test_all.py index eeb67e4b..4df4a361 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,7 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) - for mod_name in sys.modules: + for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From 23841dfdb319fbb66a4065e0c138235c56e611f0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:28:03 +0100 Subject: [PATCH 017/151] TST: update the torch skiplist --- torch-xfails.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch-xfails.txt b/torch-xfails.txt index 6e8f7dc6..f8333d90 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,10 +144,12 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] + +# https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[neq] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[not_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[les_equal] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[less_equal] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[greater_equal] From 3b4ea593d43c3d522aa1e601a93781774606bbc3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Mar 2025 09:33:26 +0100 Subject: [PATCH 018/151] TST: update numpy<2 skiplists --- numpy-1-21-xfails.txt | 1 + numpy-1-26-xfails.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 7c7a0757..30cde668 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -192,6 +192,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 57259b6f..1ce28ef4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -46,6 +46,7 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] +array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From f19256e3e132f0c16147936d1cf320680366055a Mon Sep 17 00:00:00 2001 From: Neil Girdhar Date: Fri, 21 Mar 2025 07:02:22 -0400 Subject: [PATCH 019/151] Add pyprject.toml --- .github/workflows/docs-build.yml | 2 +- .github/workflows/tests.yml | 6 +- docs/dev/tests.md | 2 +- docs/requirements.txt | 6 -- pyproject.toml | 96 ++++++++++++++++++++++++++++++++ requirements-dev.txt | 8 --- ruff.toml | 17 ------ setup.py | 40 ------------- 8 files changed, 99 insertions(+), 78 deletions(-) delete mode 100644 docs/requirements.txt create mode 100644 pyproject.toml delete mode 100644 requirements-dev.txt delete mode 100644 ruff.toml delete mode 100644 setup.py diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 04c3aa66..34b9cbc6 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -10,7 +10,7 @@ jobs: - uses: actions/setup-python@v5 - name: Install Dependencies run: | - python -m pip install -r docs/requirements.txt + python -m pip install .[docs] - name: Build Docs run: | cd docs diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fcd43367..54f6f402 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,11 +29,7 @@ jobs: PIP_EXTRA='numpy==1.26.*' fi - if [ "${{ matrix.python-version }}" == "3.9" ]; then - sed -i '/^ndonnx/d' requirements-dev.txt - fi - - python -m pip install -r requirements-dev.txt $PIP_EXTRA + python -m pip install .[dev] $PIP_EXTRA - name: Run Tests run: | diff --git a/docs/dev/tests.md b/docs/dev/tests.md index 6d9d1d7b..18fb7cf5 100644 --- a/docs/dev/tests.md +++ b/docs/dev/tests.md @@ -7,7 +7,7 @@ the array API standard. There are also array-api-compat specific tests in These tests should be limited to things that are not tested by the test suite, e.g., tests for [helper functions](../helper-functions.rst) or for behavior that is not strictly required by the standard. To run these tests, install the -dependencies from `requirements-dev.txt` (array-api-compat has [no hard +dependencies from the `dev` optional group (array-api-compat has [no hard runtime dependencies](no-dependencies)). array-api-tests is run against all supported libraries are tested on CI diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index dbec7740..00000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -furo -linkify-it-py -myst-parser -sphinx -sphinx-copybutton -sphinx-autobuild diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f17c720f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,96 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "array-api-compat" +dynamic = ["version"] +description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" +readme = "README.md" +requires-python = ">=3.9" +license = "MIT" +authors = [{name = "Consortium for Python Data API Standards"}] +classifiers = [ + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +[project.optional-dependencies] +cupy = ["cupy"] +dask = ["dask"] +jax = ["jax"] +numpy = ["numpy"] +pytorch = ["torch"] +sparse = ["sparse>=0.15.1"] +docs = [ + "furo", + "linkify-it-py", + "myst-parser", + "sphinx", + "sphinx-copybutton", + "sphinx-autobuild", +] +dev = [ + "array-api-strict", + "dask[array]", + "jax[cpu]", + "numpy", + "pytest", + "torch", + "sparse>=0.15.1", + "ndonnx; python_version>=\"3.10\"" +] + +[project.urls] +homepage = "https://data-apis.org/array-api-compat/" +repository = "https://github.com/data-apis/array-api-compat/" + +[tool.setuptools.dynamic] +version = {attr = "array_api_compat.__version__"} + +[tool.setuptools.packages.find] +include = ["array_api_compat*"] +namespaces = false + +[toolint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] + +[tool.ruff.lint] +preview = true +select = [ +# Defaults +"E4", "E7", "E9", "F", +# Undefined export +"F822", +# Useless import alias +"PLC0414" +] + +ignore = [ + # Module import not at top of file + "E402", + # Do not use bare `except` + "E722" +] diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index c9d10f71..00000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,8 +0,0 @@ -array-api-strict -dask[array] -jax[cpu] -numpy -pytest -torch -sparse >=0.15.1 -ndonnx diff --git a/ruff.toml b/ruff.toml deleted file mode 100644 index 72e111b5..00000000 --- a/ruff.toml +++ /dev/null @@ -1,17 +0,0 @@ -[lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] - -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" -] diff --git a/setup.py b/setup.py deleted file mode 100644 index 2368ccc4..00000000 --- a/setup.py +++ /dev/null @@ -1,40 +0,0 @@ -from setuptools import setup, find_packages - -with open("README.md", "r") as fh: - long_description = fh.read() - -import array_api_compat - -setup( - name='array_api_compat', - version=array_api_compat.__version__, - packages=find_packages(include=["array_api_compat*"]), - author="Consortium for Python Data API Standards", - description="A wrapper around NumPy and other array libraries to make them compatible with the Array API standard", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://data-apis.org/array-api-compat/", - license="MIT", - extras_require={ - "numpy": "numpy", - "cupy": "cupy", - "jax": "jax", - "pytorch": "pytorch", - "dask": "dask", - "sparse": "sparse >=0.15.1", - }, - python_requires=">=3.9", - classifiers=[ - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - ], - package_data={ - "array_api_compat": ["py.typed"], - }, -) From 1db3fae0f682199bda3ae920f8a695e4f579b439 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 25 Mar 2025 18:07:45 +0000 Subject: [PATCH 020/151] ENH: correct Dask capabilities --- array_api_compat/dask/array/_info.py | 22 ++++++++++++++++------ dask-xfails.txt | 8 +++++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index e15a69f4..fc70b5a2 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -68,11 +68,22 @@ def capabilities(self): The resulting dictionary has the following keys: - **"boolean indexing"**: boolean indicating whether an array library - supports boolean indexing. Always ``False`` for Dask. + supports boolean indexing. + + Dask support boolean indexing as long as both the index + and the indexed arrays have known shapes. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. - **"data-dependent shapes"**: boolean indicating whether an array - library supports data-dependent output shapes. Always ``False`` for - Dask. + library supports data-dependent output shapes. + + Dask implements unique_values et.al. + Note however that the output .shape and .size properties + will contain a non-compliant math.nan instead of None. + + - **"max dimensions"**: integer indicating the maximum number of + dimensions supported by the array library. See https://data-apis.org/array-api/latest/API_specification/generated/array_api.info.capabilities.html @@ -99,9 +110,8 @@ def capabilities(self): """ return { - "boolean indexing": False, - "data-dependent shapes": False, - # 'max rank' will be part of the 2024.12 standard + "boolean indexing": True, + "data-dependent shapes": True, "max dimensions": 64, } diff --git a/dask-xfails.txt b/dask-xfails.txt index d2474f9f..bd65d004 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -28,12 +28,14 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# Fails because shape is NaN since we don't materialize it yet +# Data-dependent output shape +# These tests fail as array-api-tests doesn't cope with unknown shapes +# Also, output shape is (math.nan, ) instead of (None, ) +# Also, da.unique() doesn't accept equals_nan which causes non-compliant +# output when there are NaNs in the input. array_api_tests/test_searching_functions.py::test_nonzero array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_counts - -# Different error but same cause as above, we're just trying to do ndindex on nan shape array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_set_functions.py::test_unique_values From 71d90ead399c03f5fcbc15d205d7cedb6bc9825c Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 30 Mar 2025 09:19:56 +0100 Subject: [PATCH 021/151] Update test_all.py Co-authored-by: Evgeni Burovski --- tests/test_all.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_all.py b/tests/test_all.py index 4df4a361..271cd189 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -33,6 +33,7 @@ def test_all(library): else: import_(library, wrapper=True) + # NB: iterate over a copy to avoid a "dictionary size changed" error for mod_name in sys.modules.copy(): if not mod_name.startswith('array_api_compat.' + library): continue From b2af137864a484908fc96fddb1e47af56f0a4adf Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 31 Mar 2025 23:51:14 +0100 Subject: [PATCH 022/151] TYP: Type annotations overhaul, part 2 (#291) --- array_api_compat/common/_aliases.py | 4 ++-- array_api_compat/cupy/_aliases.py | 5 ++++- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 5 ++++- array_api_compat/torch/_aliases.py | 14 +++++++++++--- array_api_compat/torch/linalg.py | 2 +- 6 files changed, 26 insertions(+), 9 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 0d1ecfbc..03910681 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -73,7 +73,7 @@ def eye( def full( shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, xp: Namespace, *, dtype: Optional[DType] = None, @@ -86,7 +86,7 @@ def full( def full_like( x: Array, /, - fill_value: complex, + fill_value: bool | int | float | complex, *, xp: Namespace, dtype: Optional[DType] = None, diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ebc7ccd9..423fd10a 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -68,7 +68,10 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e737cebd..e6eff359 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -136,7 +136,10 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 6536d9a8..1d084b2b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -77,7 +77,10 @@ def _supports_buffer_protocol(obj): # rather than trying to combine everything into one function in common/ def asarray( obj: ( - Array | bool | complex | NestedSequence[bool | complex] | SupportsBufferProtocol + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol ), /, *, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 87d32d85..982500b0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -116,7 +116,9 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: +def result_type( + *arrays_and_dtypes: Array | DType | bool | int | float | complex +) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -550,10 +552,16 @@ def count_nonzero( return result -def where(condition: Array, x1: Array, x2: Array, /) -> Array: +def where( + condition: Array, + x1: Array | bool | int | float | complex, + x2: Array | bool | int | float | complex, + /, +) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) + # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, @@ -622,7 +630,7 @@ def linspace(start: Union[int, float], # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 def full(shape: Union[int, Tuple[int, ...]], - fill_value: complex, + fill_value: bool | int | float | complex, *, dtype: Optional[DType] = None, device: Optional[Device] = None, diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7b59a670..1ff7319d 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -85,7 +85,7 @@ def vector_norm( axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float, float] = 2, + ord: Union[int, float] = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None From 29f494160a7657dc4da21113851bd6880e39dc7c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 10:26:59 +0100 Subject: [PATCH 023/151] TST: bump to ndonnx 0.10.1 --- tests/test_common.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index f86e0936..bbf14572 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -234,6 +234,7 @@ def test_asarray_cross_library(source_library, target_library, request): # TODO: remove xfail once # https://github.com/dask/dask/issues/8260 is resolved xfail(request, reason="Bug in dask raising error on conversion") + elif ( source_library == "ndonnx" and target_library not in ("array_api_strict", "ndonnx", "numpy") @@ -241,6 +242,9 @@ def test_asarray_cross_library(source_library, target_library, request): xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown") elif source_library == "ndonnx" and target_library == "numpy": xfail(request, reason="produces numpy array of ndonnx scalar arrays") + elif target_library == "ndonnx" and source_library in ("torch", "dask.array", "jax.numpy"): + xfail(request, reason="unable to infer dtype") + elif source_library == "jax.numpy" and target_library == "torch": xfail(request, reason="casts int to float") elif source_library == "cupy" and target_library != "cupy": From f80f15792ec981e943bef7f49faff687ef29b27c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 2 Apr 2025 17:07:20 +0100 Subject: [PATCH 024/151] ENH: wrap iinfo/finfo --- array_api_compat/common/_aliases.py | 21 +++++++++++++++++++-- array_api_compat/cupy/_aliases.py | 2 ++ array_api_compat/dask/array/_aliases.py | 9 ++++----- array_api_compat/numpy/_aliases.py | 2 ++ array_api_compat/torch/_aliases.py | 5 ++++- cupy-xfails.txt | 11 ++++++++--- dask-xfails.txt | 10 ++++++++-- numpy-1-21-xfails.txt | 12 +++++++++--- numpy-1-26-xfails.txt | 10 ++++++++-- numpy-dev-xfails.txt | 10 ++++++++-- numpy-xfails.txt | 10 ++++++++-- torch-xfails.txt | 4 ++++ 12 files changed, 84 insertions(+), 22 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 03910681..46cbb359 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,7 +5,7 @@ from __future__ import annotations import inspect -from typing import NamedTuple, Optional, Sequence, Tuple, Union +from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union from ._typing import Array, Device, DType, Namespace from ._helpers import ( @@ -609,6 +609,23 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: out[xp.isnan(x)] = xp.nan return out[()] + +def finfo(type_: DType | Array, /, xp: Namespace) -> Any: + # It is surprisingly difficult to recognize a dtype apart from an array. + # np.int64 is not the same as np.asarray(1).dtype! + try: + return xp.finfo(type_) + except (ValueError, TypeError): + return xp.finfo(type_.dtype) + + +def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: + try: + return xp.iinfo(type_) + except (ValueError, TypeError): + return xp.iinfo(type_.dtype) + + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', @@ -616,6 +633,6 @@ def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign'] + 'unstack', 'sign', 'finfo', 'iinfo'] _all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 423fd10a..fd1460ae 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) sign = get_xp(cp)(_aliases.sign) +finfo = get_xp(cp)(_aliases.finfo) +iinfo = get_xp(cp)(_aliases.iinfo) _copy_default = object() diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..dca6d570 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -5,8 +5,6 @@ import numpy as np from numpy import ( # dtypes - iinfo, - finfo, bool_ as bool, float32, float64, @@ -131,6 +129,8 @@ def arange( matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) # asarray also adds the copy keyword, which is not present in numpy 1.0. @@ -343,10 +343,9 @@ def count_nonzero( '__array_namespace_info__', 'asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'iinfo', 'finfo', 'can_cast', + 'bitwise_right_shift', 'concat', 'pow', 'can_cast', 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'complex64', 'complex128', 'iinfo', 'finfo', + 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', 'can_cast', 'count_nonzero', 'result_type'] _all_ignore = ["array_namespace", "get_xp", "da", "np"] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..ae0d006d 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,8 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) +finfo = get_xp(np)(_aliases.finfo) +iinfo = get_xp(np)(_aliases.iinfo) def _supports_buffer_protocol(obj): diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..9384e4c0 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -227,6 +227,9 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) +finfo = get_xp(torch)(_aliases.finfo) +iinfo = get_xp(torch)(_aliases.iinfo) + # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 @@ -832,6 +835,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] _all_ignore = ['torch', 'get_xp'] diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 3d20d745..f4cd1e36 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -14,9 +14,14 @@ array_api_tests/test_array_object.py::test_getitem # copy=False is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo test is testing that the result is a float instead of float32 (see -# also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index bd65d004..abab825c 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -12,8 +12,14 @@ array_api_tests/test_array_object.py::test_getitem_masking # zero division error, and typeerror: tuple indices must be integers or slices not tuple array_api_tests/test_creation_functions.py::test_eye -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 30cde668..93a90757 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -1,8 +1,14 @@ # asarray(copy=False) is not yet implemented array_api_tests/test_creation_functions.py::test_asarray_arrays -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] @@ -41,7 +47,7 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices ############################ # finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo[float64] +array_api_tests/test_data_type_functions.py::test_finfo # dlpack stuff array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 1ce28ef4..84916e73 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 98659710..31bcb63b 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0885dcaa..0810aea6 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,5 +1,11 @@ -# finfo(float32).eps returns float32 but should return float -array_api_tests/test_data_type_functions.py::test_finfo[float32] +# attributes are np.float32 instead of float +# (see also https://github.com/data-apis/array-api/issues/405) +array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array] +array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] +array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..e556fa4f 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -115,6 +115,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values +# finfo/iinfo.dtype is a string instead of a dtype +array_api_tests/test_data_type_functions.py::test_finfo_dtype +array_api_tests/test_data_type_functions.py::test_iinfo_dtype + # 2023.12 support array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] array_api_tests/test_manipulation_functions.py::test_repeat From 37b1c475c98fb092135ef021f11b7f79cd46debd Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 12:42:35 +0100 Subject: [PATCH 025/151] MAINT: validate device on numpy and dask --- array_api_compat/common/_helpers.py | 24 +++++++++++++++++++++--- array_api_compat/dask/array/_aliases.py | 5 ++++- array_api_compat/numpy/_aliases.py | 6 +++--- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 6d95069d..67c619b8 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -595,11 +595,29 @@ def your_function(x, y): # backwards compatibility alias get_namespace = array_namespace -def _check_device(xp, device): - if xp == sys.modules.get('numpy'): - if device not in ["cpu", None]: + +def _check_device(bare_xp, device): + """ + Validate dummy device on device-less array backends. + + Notes + ----- + This function is also invoked by CuPy, which does have multiple devices + if there are multiple GPUs available. + However, CuPy multi-device support is currently impossible + without using the global device or a context manager: + + https://github.com/data-apis/array-api-compat/pull/293 + """ + if bare_xp is sys.modules.get('numpy'): + if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") + elif bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE, None): + raise ValueError(f"Unsupported device for Dask: {device!r}") + + # Placeholder object to represent the dask device # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e6eff359..c5cd7489 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -25,7 +25,7 @@ ) import dask.array as da -from ...common import _aliases, array_namespace +from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, Device, @@ -56,6 +56,7 @@ def astype( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if not copy and dtype == x.dtype: return x @@ -86,6 +87,7 @@ def arange( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) args = [start] if stop is not None: @@ -155,6 +157,7 @@ def asarray( specification for more details. """ # TODO: respect device keyword? + _helpers._check_device(da, device) if isinstance(obj, da.Array): if dtype is not None and dtype != obj.dtype: diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..d5b7feac 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -3,7 +3,7 @@ from typing import Optional, Union from .._internal import get_xp -from ..common import _aliases +from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -95,8 +95,7 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - if device not in ["cpu", None]: - raise ValueError(f"Unsupported device for NumPy: {device!r}") + _helpers._check_device(np, device) if hasattr(np, '_CopyMode'): if copy is None: @@ -122,6 +121,7 @@ def astype( copy: bool = True, device: Optional[Device] = None, ) -> Array: + _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) From 2c1cb6b515849048cd062e31462b6a193b81471c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 3 Apr 2025 13:30:49 +0100 Subject: [PATCH 026/151] BUG: Don't import helpers in namespaces --- array_api_compat/common/_linalg.py | 2 ++ array_api_compat/cupy/__init__.py | 3 --- array_api_compat/dask/array/__init__.py | 1 + array_api_compat/numpy/__init__.py | 9 --------- array_api_compat/numpy/_aliases.py | 2 +- array_api_compat/torch/__init__.py | 6 ++---- tests/test_common.py | 2 +- 7 files changed, 7 insertions(+), 18 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index c77ee3b8..d1e7ebd8 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -174,3 +174,5 @@ def trace( 'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm', 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] + +_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 59e01058..9a30f95d 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -8,9 +8,6 @@ # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F401,F403 - __array_api_version__ = '2024.12' diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a6e69ad3..bb649306 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -5,5 +5,6 @@ __array_api_version__ = '2024.12' +# See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 02c55d28..6a5d9867 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -14,17 +14,8 @@ # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') from .linalg import matrix_transpose, vecdot # noqa: F401 -from ..common._helpers import * # noqa: F403 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass - __array_api_version__ = '2024.12' diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 1d084b2b..9e4f1174 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -86,7 +86,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: "Optional[Union[bool, np._CopyMode]]" = None, + copy: Optional[Union[bool, np._CopyMode]] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index a985986e..69fd19ce 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -9,16 +9,14 @@ or 'cpu' in n or 'backward' in n): continue - exec(n + ' = torch.' + n) + exec(f"{n} = torch.{n}") +del n # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') - __import__(__package__ + '.fft') -from ..common._helpers import * # noqa: F403 - __array_api_version__ = '2024.12' diff --git a/tests/test_common.py b/tests/test_common.py index bbf14572..54024d47 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,7 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') : + if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'cupy': From 621494be1bd8682f1d76ae874272c12464953d3d Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 4 Apr 2025 12:21:20 +0100 Subject: [PATCH 027/151] ENH: torch.asarray device propagation (#299) --- array_api_compat/torch/_aliases.py | 31 ++++++++++++++++++++++++------ array_api_compat/torch/_typing.py | 5 ++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 982500b0..0891525a 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,12 +2,13 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union import torch from .._internal import get_xp from ..common import _aliases +from ..common._typing import NestedSequence, SupportsBufferProtocol from ._info import __array_namespace_info__ from ._typing import Array, Device, DType @@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: remainder = _two_arg(torch.remainder) subtract = _two_arg(torch.subtract) + +def asarray( + obj: ( + Array + | bool | int | float | complex + | NestedSequence[bool | int | float | complex] + | SupportsBufferProtocol + ), + /, + *, + dtype: DType | None = None, + device: Device | None = None, + copy: bool | None = None, + **kwargs: Any, +) -> Array: + # torch.asarray does not respect input->output device propagation + # https://github.com/pytorch/pytorch/issues/150199 + if device is None and isinstance(obj, torch.Tensor): + device = obj.device + return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs) + + # These wrappers are mostly based on the fact that pytorch uses 'dim' instead # of 'axis'. @@ -282,7 +305,6 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic @@ -318,7 +340,6 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim # https://github.com/pytorch/pytorch/issues/29137. @@ -348,7 +369,6 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -373,7 +393,6 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - x = torch.asarray(x) ndim = x.ndim if axis == (): return x.to(torch.bool) @@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array: return out -__all__ = ['__array_namespace_info__', 'result_type', 'can_cast', +__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', diff --git a/array_api_compat/torch/_typing.py b/array_api_compat/torch/_typing.py index 29ad3fa7..52670871 100644 --- a/array_api_compat/torch/_typing.py +++ b/array_api_compat/torch/_typing.py @@ -1,4 +1,3 @@ -__all__ = ["Array", "DType", "Device"] +__all__ = ["Array", "Device", "DType"] -from torch import dtype as DType, Tensor as Array -from ..common._typing import Device +from torch import device as Device, dtype as DType, Tensor as Array From c629a64c928bd76fdf0bec28a1399467801364be Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 7 Apr 2025 10:27:05 +0100 Subject: [PATCH 028/151] Simplify test parametrization --- cupy-xfails.txt | 8 ++------ dask-xfails.txt | 8 ++------ numpy-1-21-xfails.txt | 8 ++------ numpy-1-26-xfails.txt | 8 ++------ numpy-dev-xfails.txt | 8 ++------ numpy-xfails.txt | 8 ++------ 6 files changed, 12 insertions(+), 36 deletions(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index f4cd1e36..a30572f8 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -16,12 +16,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Some array attributes are missing, and we do not wrap the array object array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/dask-xfails.txt b/dask-xfails.txt index abab825c..932aeada 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -14,12 +14,8 @@ array_api_tests/test_creation_functions.py::test_eye # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # out[-1]=dask.array but should be some floating number # (I think the test is not forcing the op to be computed?) diff --git a/numpy-1-21-xfails.txt b/numpy-1-21-xfails.txt index 93a90757..66443a73 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-21-xfails.txt @@ -3,12 +3,8 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 84916e73..ed95083a 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 31bcb63b..972d2346 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0810aea6..0f09985e 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -1,11 +1,7 @@ # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) -array_api_tests/test_data_type_functions.py::test_finfo[float32-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array] -array_api_tests/test_data_type_functions.py::test_finfo[float32-array.dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-dtype] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array] -array_api_tests/test_data_type_functions.py::test_finfo[complex64-array.dtype] +array_api_tests/test_data_type_functions.py::test_finfo[float32] +array_api_tests/test_data_type_functions.py::test_finfo[complex64] # The test suite cannot properly get the signature for vecdot # https://github.com/numpy/numpy/pull/26237 From bff3bf467d6f126015179558f1b8c71242014cbc Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 11:48:45 +0100 Subject: [PATCH 029/151] Drop Python 3.9; test on Python 3.13; drop NumPy 1.21; skip CUDA install (#304) reviewed at https://github.com/data-apis/array-api-compat/pull/304 --- .github/workflows/array-api-tests-dask.yml | 2 +- .../workflows/array-api-tests-numpy-1-21.yml | 11 --- .../workflows/array-api-tests-numpy-1-22.yml | 12 +++ .../workflows/array-api-tests-numpy-1-26.yml | 1 + .../workflows/array-api-tests-numpy-dev.yml | 1 + .../array-api-tests-numpy-latest.yml | 3 +- .github/workflows/array-api-tests-torch.yml | 4 +- .github/workflows/array-api-tests.yml | 23 +++-- .github/workflows/tests.yml | 58 ++++++++----- array_api_compat/cupy/_typing.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/numpy/_aliases.py | 18 ++-- array_api_compat/numpy/_typing.py | 2 +- docs/supported-array-libraries.md | 17 +--- ...y-1-21-xfails.txt => numpy-1-22-xfails.txt | 83 +++---------------- numpy-1-26-xfails.txt | 3 - numpy-skips.txt | 11 --- numpy-xfails.txt | 4 +- pyproject.toml | 16 ++-- tests/test_common.py | 5 +- tests/test_dask.py | 6 +- torch-skips.txt | 11 --- torch-xfails.txt | 4 + 23 files changed, 114 insertions(+), 185 deletions(-) delete mode 100644 .github/workflows/array-api-tests-numpy-1-21.yml create mode 100644 .github/workflows/array-api-tests-numpy-1-22.yml rename numpy-1-21-xfails.txt => numpy-1-22-xfails.txt (68%) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 2ad98586..afc67975 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -7,7 +7,6 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: dask - package-version: '>= 2024.9.0' module-name: dask.array extra-requires: numpy # Dask is substantially slower then other libraries on unit tests. @@ -16,3 +15,4 @@ jobs: # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. pytest-extra-args: --max-examples=5 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-1-21.yml b/.github/workflows/array-api-tests-numpy-1-21.yml deleted file mode 100644 index 2d81c3cd..00000000 --- a/.github/workflows/array-api-tests-numpy-1-21.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Array API Tests (NumPy 1.21) - -on: [push, pull_request] - -jobs: - array-api-tests-numpy-1-21: - uses: ./.github/workflows/array-api-tests.yml - with: - package-name: numpy - package-version: '== 1.21.*' - xfails-file-extra: '-1-21' diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml new file mode 100644 index 00000000..d8f60432 --- /dev/null +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -0,0 +1,12 @@ +name: Array API Tests (NumPy 1.22) + +on: [push, pull_request] + +jobs: + array-api-tests-numpy-1-22: + uses: ./.github/workflows/array-api-tests.yml + with: + package-name: numpy + package-version: '== 1.22.*' + xfails-file-extra: '-1-22' + python-versions: '[''3.10'']' diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 660935f0..33780760 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy package-version: '== 1.26.*' xfails-file-extra: '-1-26' + python-versions: '[''3.10'', ''3.12'']' diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index eef4269d..d6de1a53 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,3 +9,4 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' + python-versions: '[''3.11'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 36984345..4d3667f6 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -1,4 +1,4 @@ -name: Array API Tests (NumPy Latest) +name: Array API Tests (NumPy latest) on: [push, pull_request] @@ -7,3 +7,4 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 56ab81a3..ac20df25 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -1,4 +1,4 @@ -name: Array API Tests (PyTorch Latest) +name: Array API Tests (PyTorch CPU) on: [push, pull_request] @@ -7,5 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: torch + extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + python-versions: '[''3.10'', ''3.13'']' diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 6ace193a..31bedde6 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -16,6 +16,10 @@ on: required: false type: string default: '>= 0' + python-versions: + required: true + type: string + description: JSON array of Python versions to test against. pytest-extra-args: required: false type: string @@ -30,7 +34,7 @@ on: extra-env-vars: required: false type: string - description: "Multiline string of environment variables to set for the test run." + description: Multiline string of environment variables to set for the test run. env: PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" @@ -39,41 +43,44 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - # Min version of dask we need dropped support for Python 3.9 - # There is no numpy git tip for Python 3.9 or 3.10 - python-version: ${{ (inputs.package-name == 'dask' && fromJson('[''3.10'', ''3.11'', ''3.12'']')) || (inputs.package-name == 'numpy' && inputs.xfails-file-extra == '-dev' && fromJson('[''3.11'', ''3.12'']')) || fromJson('[''3.9'', ''3.10'', ''3.11'', ''3.12'']') }} + python-version: ${{ fromJson(inputs.python-versions) }} steps: - name: Checkout array-api-compat uses: actions/checkout@v4 with: path: array-api-compat + - name: Checkout array-api-tests uses: actions/checkout@v4 with: repository: data-apis/array-api-tests submodules: 'true' path: array-api-tests + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Set Extra Environment Variables # Set additional environment variables if provided if: inputs.extra-env-vars run: | echo "${{ inputs.extra-env-vars }}" >> $GITHUB_ENV + - name: Install dependencies - # NumPy 1.21 doesn't support Python 3.11. There doesn't seem to be a way - # to put this in the numpy 1.21 config file. - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" run: | python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + + - name: Dump pip environment + run: pip freeze + - name: Run the array API testsuite (${{ inputs.package-name }}) - if: "! ((matrix.python-version == '3.11' || matrix.python-version == '3.12') && inputs.package-name == 'numpy' && contains(inputs.package-version, '1.21'))" env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} ARRAY_API_TESTS_VERSION: 2024.12 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 54f6f402..81a05b3f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,15 +4,24 @@ jobs: tests: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11', '3.12'] - numpy-version: ['1.21', '1.26', '2.0', 'dev'] - exclude: - - python-version: '3.11' - numpy-version: '1.21' - - python-version: '3.12' - numpy-version: '1.21' - fail-fast: true + include: + - numpy-version: '1.22' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.10' + - numpy-version: '1.26' + python-version: '3.12' + - numpy-version: 'latest' + python-version: '3.10' + - numpy-version: 'latest' + python-version: '3.13' + - numpy-version: 'dev' + python-version: '3.11' + - numpy-version: 'dev' + python-version: '3.13' + steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 @@ -21,22 +30,29 @@ jobs: - name: Install Dependencies run: | python -m pip install --upgrade pip + python -m pip install pytest + if [ "${{ matrix.numpy-version }}" == "dev" ]; then - PIP_EXTRA='numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' - elif [ "${{ matrix.numpy-version }}" == "1.21" ]; then - PIP_EXTRA='numpy==1.21.*' + python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then + python -m pip install 'numpy==1.22.*' + elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then + python -m pip install 'numpy==1.26.*' else - PIP_EXTRA='numpy==1.26.*' + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.python-version }}" != "3.13" ]; then + # onnx wheels are not available on Python 3.13 at the moment of writing + python -m pip install ndonnx + fi fi - python -m pip install .[dev] $PIP_EXTRA + - name: Dump pip environment + run: pip freeze - - name: Run Tests - run: | - if [[ "${{ matrix.numpy-version }}" == "1.21" || "${{ matrix.numpy-version }}" == "dev" ]]; then - PYTEST_EXTRA=(-k "numpy and not jax and not torch and not dask and not sparse") - fi - pytest -v "${PYTEST_EXTRA[@]}" + - name: Test it installs + run: python -m pip install . - # Make sure it installs - python -m pip install . + - name: Run Tests + run: pytest -v diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index 66af5d19..d8e49ca7 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -10,7 +10,7 @@ from cupy.cuda.device import Device if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = cp.dtype[ cp.intp | cp.int8 diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 4733b1a6..e7ddde78 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -147,7 +147,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 59a0b8f4..d1fd46a1 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -99,18 +99,12 @@ def asarray( """ _helpers._check_device(np, device) - if hasattr(np, '_CopyMode'): - if copy is None: - copy = np._CopyMode.IF_NEEDED - elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS - else: - # Not present in older NumPys. In this case, we cannot really support - # copy=False. - if copy is False: - raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.") + if copy is None: + copy = np._CopyMode.IF_NEEDED + elif copy is False: + copy = np._CopyMode.NEVER + elif copy is True: + copy = np._CopyMode.ALWAYS return np.array(obj, copy=copy, dtype=dtype, **kwargs) diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index 6a18a3b2..a6c96924 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -10,7 +10,7 @@ Device = Literal["cpu"] if TYPE_CHECKING: - # NumPy 1.x on Python 3.9 and 3.10 fails to parse np.dtype[] + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType = np.dtype[ np.intp | np.int8 diff --git a/docs/supported-array-libraries.md b/docs/supported-array-libraries.md index 4519c4ac..46fcdc27 100644 --- a/docs/supported-array-libraries.md +++ b/docs/supported-array-libraries.md @@ -36,23 +36,16 @@ deviations from the standard should be noted: 50](https://numpy.org/neps/nep-0050-scalar-promotion.html) and https://github.com/numpy/numpy/issues/22341) -- `asarray()` does not support `copy=False`. - - Functions which are not wrapped may not have the same type annotations as the spec. - Functions which are not wrapped may not use positional-only arguments. -The minimum supported NumPy version is 1.21. However, this older version of +The minimum supported NumPy version is 1.22. However, this older version of NumPy has a few issues: - `unique_*` will not compare nans as unequal. -- `finfo()` has no `smallest_normal`. - No `from_dlpack` or `__dlpack__`. -- `argmax()` and `argmin()` do not have `keepdims`. -- `qr()` doesn't support matrix stacks. -- `asarray()` doesn't support `copy=True` (as noted above, `copy=False` is not - supported even in the latest NumPy). - Type promotion behavior will be value based for 0-D arrays (and there is no `NPY_PROMOTION_STATE=weak` to disable this). @@ -72,8 +65,8 @@ version. attribute in the spec. Use the {func}`~.size()` helper function as a portable workaround. -- PyTorch does not have unsigned integer types other than `uint8`, and no - attempt is made to implement them here. +- PyTorch has incomplete support for unsigned integer types other + than `uint8`, and no attempt is made to implement them here. - PyTorch has type promotion semantics that differ from the array API specification for 0-D tensor objects. The array functions in this wrapper @@ -100,8 +93,6 @@ version. - As with NumPy, type annotations and positional-only arguments may not exactly match the spec for functions that are not wrapped at all. -The minimum supported PyTorch version is 1.13. - (jax-support)= ## [JAX](https://jax.readthedocs.io/en/latest/) @@ -131,8 +122,6 @@ For `linalg`, several methods are missing, for example: - `matrix_rank` Other methods may only be partially implemented or return incorrect results at times. -The minimum supported Dask version is 2023.12.0. - (sparse-support)= ## [Sparse](https://sparse.pydata.org/en/stable/) diff --git a/numpy-1-21-xfails.txt b/numpy-1-22-xfails.txt similarity index 68% rename from numpy-1-21-xfails.txt rename to numpy-1-22-xfails.txt index 66443a73..93edf311 100644 --- a/numpy-1-21-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -1,6 +1,3 @@ -# asarray(copy=False) is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] @@ -39,38 +36,24 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and # https://github.com/numpy/numpy/issues/21213 array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices -# NumPy 1.21 specific XFAILS +# NumPy 1.22 specific XFAILS ############################ -# finfo has no smallest_normal -array_api_tests/test_data_type_functions.py::test_finfo - -# dlpack stuff -array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] -array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] - -# qr() doesn't support matrix stacks -array_api_tests/test_linalg.py::test_qr - # cross has some promotion bug that is fixed in newer numpy versions array_api_tests/test_linalg.py::test_cross +# linspace(-0.0, -1.0, num=1) returns +0.0 instead of -0.0. +# Fixed in newer numpy versions. +array_api_tests/test_creation_functions.py::test_linspace + # vector_norm with ord=-1 which has since been fixed # https://github.com/numpy/numpy/issues/21083 array_api_tests/test_linalg.py::test_vector_norm -# argmax and argmin do not support keepdims -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_signatures.py::test_func_signature[argmax] -array_api_tests/test_signatures.py::test_func_signature[argmin] - -# NumPy 1.21 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with +# NumPy 1.22 doesn't support NPY_PROMOTION_STATE=weak, so many tests fail with # type promotion issues +# NOTE: some of these may not fail until one runs array-api-tests with +# --max-examples 100000 array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x1, x2)] @@ -109,6 +92,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[_ array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[greater(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[greater_equal(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_hypot array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] @@ -136,53 +120,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isu array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtract(x1, x2)] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__floordiv__(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[__mod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i < 0 and isfinite(x1_i) and isfinite(x2_i) and not x2_i.is_integer()) -> NaN] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i < 0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -0 and x2_i > 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and not (x2_i.is_integer() and x2_i % 2 == 1)) -> +0] -array_api_tests/test_special_cases.py::test_binary[__pow__(x1_i is -infinity and x2_i < 0 and x2_i.is_integer() and x2_i % 2 == 1) -> -0] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is +0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i < 0 and x2_i is -0) -> roughly -pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is +0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[atan2(x1_i > 0 and x2_i is -0) -> roughly +pi/2] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is +0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i < 0 and x2_i is -0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is +0) -> +infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i > 0 and x2_i is -0) -> -infinity] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i < 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -0 and x2_i > 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i < 0 and x2_i is -infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is +infinity) -> x1_i] -array_api_tests/test_special_cases.py::test_binary[remainder(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> x2_i] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is +0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i < 0) -> -0] -array_api_tests/test_special_cases.py::test_binary[remainder(x1_i is -0 and x2_i > 0) -> +0] +array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] # 2023.12 support +array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported @@ -215,6 +157,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (np.finfo(float32).max + 1) -> float64 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index ed95083a..51e1a658 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -69,6 +69,3 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] - -# numpy < 2 bug: type promotion of asarray([], 'float32') and (finfo(float32).max + 1) gives float64 not float32 -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real diff --git a/numpy-skips.txt b/numpy-skips.txt index cbf7235b..e69de29b 100644 --- a/numpy-skips.txt +++ b/numpy-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 0f09985e..632b4ec3 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -9,8 +9,6 @@ array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # 2023.12 support -array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat @@ -20,6 +18,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] + +# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] diff --git a/pyproject.toml b/pyproject.toml index f17c720f..aacebd11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,13 +7,12 @@ name = "array-api-compat" dynamic = ["version"] description = "A wrapper around NumPy and other array libraries to make them compatible with the Array API standard" readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" license = "MIT" authors = [{name = "Consortium for Python Data API Standards"}] classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -24,11 +23,14 @@ classifiers = [ [project.optional-dependencies] cupy = ["cupy"] -dask = ["dask"] +dask = ["dask>=2024.9.0"] jax = ["jax"] -numpy = ["numpy"] +# Note: array-api-compat follows scikit-learn minimum dependencies, which support +# much older versions of NumPy than what SPEC0 recommends. +numpy = ["numpy>=1.22"] pytorch = ["torch"] sparse = ["sparse>=0.15.1"] +ndonnx = ["ndonnx"] docs = [ "furo", "linkify-it-py", @@ -39,13 +41,13 @@ docs = [ ] dev = [ "array-api-strict", - "dask[array]", + "dask[array]>=2024.9.0", "jax[cpu]", - "numpy", + "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx; python_version>=\"3.10\"" + "ndonnx" ] [project.urls] diff --git a/tests/test_common.py b/tests/test_common.py index 54024d47..6b1aa160 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -276,10 +276,7 @@ def test_asarray_copy(library): is_lib_func = globals()[is_array_functions[library]] all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(np, "_CopyMode"): - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'cupy': + if library == 'cupy': supports_copy_false_other_ns = False supports_copy_false_same_ns = False elif library == 'dask.array': diff --git a/tests/test_dask.py b/tests/test_dask.py index 69c738f6..fb0a84d4 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -import array_api_strict import numpy as np import pytest @@ -171,9 +170,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks): @pytest.mark.parametrize("func", ["sort", "argsort"]) def test_sort_argsort_meta(xp, func): """Test meta-namespace other than numpy""" - typ = type(array_api_strict.asarray(0)) + mxp = pytest.importorskip("array_api_strict") + typ = type(mxp.asarray(0)) a = da.random.random(10) - b = a.map_blocks(array_api_strict.asarray) + b = a.map_blocks(mxp.asarray) assert isinstance(b._meta, typ) c = getattr(xp, func)(b) assert isinstance(c._meta, typ) diff --git a/torch-skips.txt b/torch-skips.txt index cbf7235b..e69de29b 100644 --- a/torch-skips.txt +++ b/torch-skips.txt @@ -1,11 +0,0 @@ -# These tests cause a core dump on CI, so we have to skip them entirely -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__imod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[__mod__(x1, x2)] -array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..abee88b1 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -29,6 +29,10 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__trued array_api_tests/test_operators_and_elementwise_functions.py::test_equal[__eq__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_not_equal[__ne__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x, s)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__ifloordiv__(x1, x2)] +array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[floor_divide(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater[__gt__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_greater_equal[__ge__(x1, x2)] array_api_tests/test_operators_and_elementwise_functions.py::test_less[__lt__(x1, x2)] From 00e7cceb338025d9428af2bb6afbe7eaac8cf414 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 11:53:21 +0200 Subject: [PATCH 030/151] BUG: add torch.repeat --- array_api_compat/torch/_aliases.py | 7 ++++++- torch-xfails.txt | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..0a604b8c 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -574,6 +574,11 @@ def count_nonzero( return result +# "repeat" is torch.repeat_interleave; also the dim argument +def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + return torch.repeat_interleave(x, repeats, axis) + + def where( condition: Array, x1: Array | bool | int | float | complex, @@ -854,6 +859,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] _all_ignore = ['torch', 'get_xp'] diff --git a/torch-xfails.txt b/torch-xfails.txt index e556fa4f..ab11f457 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -120,9 +120,8 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo_dtype # 2023.12 support -array_api_tests/test_has_names.py::test_has_names[manipulation-repeat] +# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers array_api_tests/test_manipulation_functions.py::test_repeat -array_api_tests/test_signatures.py::test_func_signature[repeat] # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature From d743dc13e16a2328e3ce0951dd3633629b6537a6 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 11:54:15 +0100 Subject: [PATCH 031/151] MAINT: `__array_namespace_info__` docstrings tweaks (#300) --- array_api_compat/common/_aliases.py | 2 +- array_api_compat/cupy/_info.py | 20 ++++++++++---- array_api_compat/dask/array/_info.py | 19 +++++++------ array_api_compat/numpy/_info.py | 8 +++--- array_api_compat/torch/_info.py | 41 ++++++++++++++++++---------- 5 files changed, 56 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 46cbb359..351b5bd6 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -18,7 +18,7 @@ # These functions are modified from the NumPy versions. -# Creation functions add the device keyword (which does nothing for NumPy) +# Creation functions add the device keyword (which does nothing for NumPy and Dask) def arange( start: Union[int, float], diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 790621e4..78e48a33 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -26,6 +26,7 @@ complex128, ) + class __array_namespace_info__: """ Get the array API inspection namespace for CuPy. @@ -49,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': cupy.float64, 'complex floating': cupy.complex128, @@ -94,13 +95,13 @@ def capabilities(self): >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -117,7 +118,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new CuPy arrays. Examples @@ -126,6 +127,15 @@ def default_device(self): >>> info.default_device() Device(0) + Notes + ----- + This method returns the static default device when CuPy is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed globally or with a context manager. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return cuda.Device(0) @@ -312,7 +322,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by CuPy. See Also diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index fc70b5a2..614f43d9 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -50,7 +50,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -103,10 +103,11 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { @@ -130,12 +131,12 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new Dask arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() 'cpu' @@ -173,7 +174,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': dask.float64, 'complex floating': dask.complex128, @@ -239,7 +240,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': dask.int8, 'int16': dask.int16, @@ -335,7 +336,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by Dask. See Also @@ -347,7 +348,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() ['cpu', DASK_DEVICE] diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index e706d118..365855b8 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -94,13 +94,13 @@ def capabilities(self): >>> info = np.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -119,7 +119,7 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new NumPy arrays. Examples @@ -326,7 +326,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by NumPy. See Also diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 34fbcb21..818e5d37 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -34,7 +34,7 @@ class __array_namespace_info__: Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': numpy.float64, 'complex floating': numpy.complex128, @@ -76,16 +76,16 @@ def capabilities(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.capabilities() {'boolean indexing': True, - 'data-dependent shapes': True} + 'data-dependent shapes': True, + 'max dimensions': 64} """ return { "boolean indexing": True, "data-dependent shapes": True, - # 'max rank' will be part of the 2024.12 standard "max dimensions": 64, } @@ -102,15 +102,24 @@ def default_device(self): Returns ------- - device : str + device : Device The default device used for new PyTorch arrays. Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_device() - 'cpu' + device(type='cpu') + Notes + ----- + This method returns the static default device when PyTorch is initialized. + However, the *current* device used by creation functions (``empty`` etc.) + can be changed at runtime. + + See Also + -------- + https://github.com/data-apis/array-api/issues/835 """ return torch.device("cpu") @@ -120,9 +129,9 @@ def default_dtypes(self, *, device=None): Parameters ---------- - device : str, optional - The device to get the default data types for. For PyTorch, only - ``'cpu'`` is allowed. + device : Device, optional + The device to get the default data types for. + Unused for PyTorch, as all devices use the same default dtypes. Returns ------- @@ -139,7 +148,7 @@ def default_dtypes(self, *, device=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.default_dtypes() {'real floating': torch.float32, 'complex floating': torch.complex64, @@ -250,8 +259,9 @@ def dtypes(self, *, device=None, kind=None): Parameters ---------- - device : str, optional + device : Device, optional The device to get the data types for. + Unused for PyTorch, as all devices use the same dtypes. kind : str or tuple of str, optional The kind of data types to return. If ``None``, all data types are returned. If a string, only data types of that kind are returned. @@ -287,7 +297,7 @@ def dtypes(self, *, device=None, kind=None): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.dtypes(kind='signed integer') {'int8': numpy.int8, 'int16': numpy.int16, @@ -310,7 +320,7 @@ def devices(self): Returns ------- - devices : list of str + devices : list[Device] The devices supported by PyTorch. See Also @@ -322,7 +332,7 @@ def devices(self): Examples -------- - >>> info = np.__array_namespace_info__() + >>> info = xp.__array_namespace_info__() >>> info.devices() [device(type='cpu'), device(type='mps', index=0), device(type='meta')] @@ -333,6 +343,7 @@ def devices(self): # device: try: torch.device('notadevice') + raise AssertionError("unreachable") # pragma: nocover except RuntimeError as e: # The error message is something like: # "Expected one of cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone device type at start of device string: notadevice" From 9194c5cb7706e08f1a1092aece1fce76ac6e089a Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Tue, 15 Apr 2025 12:04:09 +0100 Subject: [PATCH 032/151] MAINT: simplify `torch` dtype promotion (#303) reviewed at https://github.com/data-apis/array-api-compat/pull/303 --- array_api_compat/torch/_aliases.py | 99 ++++++++++++------------------ 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a2ed1449..5370803f 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -35,47 +35,23 @@ torch.complex128, } -_promotion_table = { - # bool - (torch.bool, torch.bool): torch.bool, +_promotion_table = { # ints - (torch.int8, torch.int8): torch.int8, (torch.int8, torch.int16): torch.int16, (torch.int8, torch.int32): torch.int32, (torch.int8, torch.int64): torch.int64, - (torch.int16, torch.int8): torch.int16, - (torch.int16, torch.int16): torch.int16, (torch.int16, torch.int32): torch.int32, (torch.int16, torch.int64): torch.int64, - (torch.int32, torch.int8): torch.int32, - (torch.int32, torch.int16): torch.int32, - (torch.int32, torch.int32): torch.int32, (torch.int32, torch.int64): torch.int64, - (torch.int64, torch.int8): torch.int64, - (torch.int64, torch.int16): torch.int64, - (torch.int64, torch.int32): torch.int64, - (torch.int64, torch.int64): torch.int64, - # uints - (torch.uint8, torch.uint8): torch.uint8, # ints and uints (mixed sign) - (torch.int8, torch.uint8): torch.int16, - (torch.int16, torch.uint8): torch.int16, - (torch.int32, torch.uint8): torch.int32, - (torch.int64, torch.uint8): torch.int64, (torch.uint8, torch.int8): torch.int16, (torch.uint8, torch.int16): torch.int16, (torch.uint8, torch.int32): torch.int32, (torch.uint8, torch.int64): torch.int64, # floats - (torch.float32, torch.float32): torch.float32, (torch.float32, torch.float64): torch.float64, - (torch.float64, torch.float32): torch.float64, - (torch.float64, torch.float64): torch.float64, # complexes - (torch.complex64, torch.complex64): torch.complex64, (torch.complex64, torch.complex128): torch.complex128, - (torch.complex128, torch.complex64): torch.complex128, - (torch.complex128, torch.complex128): torch.complex128, # Mixed float and complex (torch.float32, torch.complex64): torch.complex64, (torch.float32, torch.complex128): torch.complex128, @@ -83,6 +59,9 @@ (torch.float64, torch.complex128): torch.complex128, } +_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()}) +_promotion_table.update({(a, a): a for a in _array_api_dtypes}) + def _two_arg(f): @_wraps(f) @@ -150,13 +129,18 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type(x, y): +def _result_type( + x: Array | DType | bool | int | float | complex, + y: Array | DType | bool | int | float | complex, +) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): - xdt = x.dtype if not isinstance(x, torch.dtype) else x - ydt = y.dtype if not isinstance(y, torch.dtype) else y + xdt = x if isinstance(x, torch.dtype) else x.dtype + ydt = y if isinstance(y, torch.dtype) else y.dtype - if (xdt, ydt) in _promotion_table: + try: return _promotion_table[xdt, ydt] + except KeyError: + pass # This doesn't result_type(dtype, dtype) for non-array API dtypes # because torch.result_type only accepts tensors. This does however, allow @@ -301,6 +285,25 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs): out = torch.unsqueeze(out, a) return out + +def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: + """ + Implements `sum(..., axis=())` and `prod(..., axis=())`. + + Works around https://github.com/pytorch/pytorch/issues/29137 + """ + if dtype is not None: + return x.clone() if dtype == x.dtype else x.to(dtype) + + # We can't upcast uint8 according to the spec because there is no + # torch.uint64, so at least upcast to int64 which is what prod does + # when axis=None. + if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32): + return x.to(torch.int64) + + return x.clone() + + def prod(x: Array, /, *, @@ -308,20 +311,9 @@ def prod(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic - # below because it still needs to upcast. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) # torch.prod doesn't support multiple axes # (https://github.com/pytorch/pytorch/issues/56586). if isinstance(axis, tuple): @@ -330,7 +322,7 @@ def prod(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.prod(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -343,25 +335,14 @@ def sum(x: Array, dtype: Optional[DType] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim - # https://github.com/pytorch/pytorch/issues/29137. - # Make sure it upcasts. if axis == (): - if dtype is None: - # We can't upcast uint8 according to the spec because there is no - # torch.uint64, so at least upcast to int64 which is what sum does - # when axis=None. - if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]: - return x.to(torch.int64) - return x.clone() - return x.to(dtype) - + return _sum_prod_no_axis(x, dtype) if axis is None: # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.sum(x, dtype=dtype, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs) @@ -372,7 +353,7 @@ def any(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.any doesn't support multiple axes @@ -384,7 +365,7 @@ def any(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.any(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.any doesn't return bool for uint8 @@ -396,7 +377,7 @@ def all(x: Array, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, **kwargs) -> Array: - ndim = x.ndim + if axis == (): return x.to(torch.bool) # torch.all doesn't support multiple axes @@ -408,7 +389,7 @@ def all(x: Array, # torch doesn't support keepdims with axis=None # (https://github.com/pytorch/pytorch/issues/71209) res = torch.all(x, **kwargs) - res = _axis_none_keepdims(res, ndim, keepdims) + res = _axis_none_keepdims(res, x.ndim, keepdims) return res.to(torch.bool) # torch.all doesn't return bool for uint8 From b94efc1f5e490a23c0ca74aafb93cc3118471f46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 15 Apr 2025 14:37:45 +0200 Subject: [PATCH 033/151] TST: skip testing nextafter with scalars on torch --- torch-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/torch-xfails.txt b/torch-xfails.txt index f8333d90..538403a3 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,6 +144,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] # https://github.com/pytorch/pytorch/issues/149815 array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[equal] From 205c967d658de24b2738dcae8d91684a1f99d2cd Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Thu, 17 Apr 2025 20:14:48 +0200 Subject: [PATCH 034/151] TYP: Type annotations overhaul, episode 2 (#288) * TYP: annotate `_internal.get_xp` (and curse at `ParamSpec` for being so useless) * TYP: fix (or ignore) typing errors in `common._helpers` (and curse at cupy) * TYP: fix typing errors in `common._fft` * TYP: fix typing errors in `common._aliases` * TYP: fix typing errors in `common._linalg` * TYP: fix/ignore typing errors in `numpy.__init__` * TYP: fix typing errors in `numpy._typing` * TYP: fix typing errors in `numpy._aliases` * TYP: fix typing errors in `numpy._info` * TYP: fix typing errors in `numpy._fft` * TYP: it's a bad idea to import `TypeAlias` from `typing` on `python<3.10` * TYP: it's also a bad idea to import `TypeGuard` from `typing` on `python<3.10` * TYP: don't scare the prehistoric `dtype` from numpy 1.21 * TYP: dust off the DeLorean * TYP: figure out how to drive a DeLorean * TYP: apply review suggestions Co-authored-by: crusaderky * TYP: sprinkle some `TypeAlias`es and `Final`s around * TYP: `__dir__` * TYP: fix typing errors in `numpy.linalg` * TYP: add a `common._typing.Capabilities` typed dict type * TYP: `__array_namespace_info__` helper types * TYP: `dask.array` typing fixes and improvements * STY: give the `=` some breathing room Co-authored-by: Lucas Colley * STY: apply review suggestions Co-authored-by: lucascolley --------- Co-authored-by: crusaderky Co-authored-by: Lucas Colley --- array_api_compat/_internal.py | 25 +- array_api_compat/common/__init__.py | 2 +- array_api_compat/common/_aliases.py | 331 +++++++++++++++--------- array_api_compat/common/_fft.py | 69 ++--- array_api_compat/common/_helpers.py | 287 +++++++++++++------- array_api_compat/common/_linalg.py | 110 ++++++-- array_api_compat/common/_typing.py | 148 ++++++++++- array_api_compat/dask/array/__init__.py | 8 +- array_api_compat/dask/array/_aliases.py | 162 +++++++----- array_api_compat/dask/array/_info.py | 96 +++++-- array_api_compat/dask/array/linalg.py | 22 +- array_api_compat/numpy/__init__.py | 28 +- array_api_compat/numpy/_aliases.py | 86 +++--- array_api_compat/numpy/_info.py | 42 ++- array_api_compat/numpy/_typing.py | 35 ++- array_api_compat/numpy/fft.py | 16 +- array_api_compat/numpy/linalg.py | 97 +++++-- 17 files changed, 1076 insertions(+), 488 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index 170a1ff9..cd8d939f 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,10 +2,16 @@ Internal helpers """ +from collections.abc import Callable from functools import wraps from inspect import signature +from types import ModuleType +from typing import TypeVar -def get_xp(xp): +_T = TypeVar("_T") + + +def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]: """ Decorator to automatically replace xp with the corresponding array module. @@ -22,14 +28,14 @@ def func(x, /, xp, kwarg=None): """ - def inner(f): + def inner(f: Callable[..., _T], /) -> Callable[..., _T]: @wraps(f) - def wrapped_f(*args, **kwargs): + def wrapped_f(*args: object, **kwargs: object) -> object: return f(*args, xp=xp, **kwargs) sig = signature(f) new_sig = sig.replace( - parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"] + parameters=[par for i, par in sig.parameters.items() if i != "xp"] ) if wrapped_f.__doc__ is None: @@ -40,7 +46,14 @@ def wrapped_f(*args, **kwargs): specification for more details. """ - wrapped_f.__signature__ = new_sig - return wrapped_f + wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # pyright: ignore[reportReturnType] return inner + + +__all__ = ["get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/__init__.py b/array_api_compat/common/__init__.py index 91ab1c40..82360807 100644 --- a/array_api_compat/common/__init__.py +++ b/array_api_compat/common/__init__.py @@ -1 +1 @@ -from ._helpers import * # noqa: F403 +from ._helpers import * # noqa: F403 diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 351b5bd6..8ea9162a 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,158 +5,170 @@ from __future__ import annotations import inspect -from typing import Any, NamedTuple, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from ._helpers import _check_device, array_namespace +from ._helpers import device as _get_device +from ._helpers import is_cupy_namespace as _is_cupy_namespace from ._typing import Array, Device, DType, Namespace -from ._helpers import ( - array_namespace, - _check_device, - device as _get_device, - is_cupy_namespace as _is_cupy_namespace -) +if TYPE_CHECKING: + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs # These functions are modified from the NumPy versions. # Creation functions add the device keyword (which does nothing for NumPy and Dask) + def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + def empty( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty(shape, dtype=dtype, **kwargs) + def empty_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.empty_like(x, dtype=dtype, **kwargs) + def eye( n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, xp: Namespace, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + def full( - shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, + shape: int | tuple[int, ...], + fill_value: complex, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full(shape, fill_value, dtype=dtype, **kwargs) + def full_like( x: Array, /, - fill_value: bool | int | float | complex, + fill_value: complex, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + def linspace( - start: Union[int, float], - stop: Union[int, float], + start: float, + stop: float, /, num: int, *, xp: Namespace, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + def ones( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones(shape, dtype=dtype, **kwargs) + def ones_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.ones_like(x, dtype=dtype, **kwargs) + def zeros( - shape: Union[int, Tuple[int, ...]], + shape: int | tuple[int, ...], xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros(shape, dtype=dtype, **kwargs) + def zeros_like( x: Array, /, xp: Namespace, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: _check_device(xp, device) return xp.zeros_like(x, dtype=dtype, **kwargs) + # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done # to remove polymorphic return types). @@ -164,6 +176,7 @@ def zeros_like( # The functions here return namedtuples (np.unique() returns a normal # tuple). + # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. class UniqueAllResult(NamedTuple): @@ -188,10 +201,11 @@ def _unique_kwargs(xp: Namespace) -> dict[str, bool]: # trying to parse version numbers, just check if equal_nan is in the # signature. s = inspect.signature(xp.unique) - if 'equal_nan' in s.parameters: - return {'equal_nan': False} + if "equal_nan" in s.parameters: + return {"equal_nan": False} return {} + def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: kwargs = _unique_kwargs(xp) values, indices, inverse_indices, counts = xp.unique( @@ -215,11 +229,7 @@ def unique_all(x: Array, /, xp: Namespace) -> UniqueAllResult: def unique_counts(x: Array, /, xp: Namespace) -> UniqueCountsResult: kwargs = _unique_kwargs(xp) res = xp.unique( - x, - return_counts=True, - return_index=False, - return_inverse=False, - **kwargs + x, return_counts=True, return_index=False, return_inverse=False, **kwargs ) return UniqueCountsResult(*res) @@ -250,51 +260,58 @@ def unique_values(x: Array, /, xp: Namespace) -> Array: **kwargs, ) + # These functions have different keyword argument names + def std( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.std(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + def var( x: Array, /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, # correction instead of ddof + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, # correction instead of ddof keepdims: bool = False, - **kwargs, + **kwargs: object, ) -> Array: return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs) + # cumulative_sum is renamed from cumsum, and adds the include_initial keyword # argument + def cumulative_sum( x: Array, /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) # TODO: The standard is not clear about what should happen when x.ndim == 0. if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_sum for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_sum for more than one dimension" + ) axis = 0 res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs) @@ -304,7 +321,12 @@ def cumulative_sum( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.zeros( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res @@ -315,16 +337,18 @@ def cumulative_prod( /, xp: Namespace, *, - axis: Optional[int] = None, - dtype: Optional[DType] = None, + axis: int | None = None, + dtype: DType | None = None, include_initial: bool = False, - **kwargs, + **kwargs: object, ) -> Array: wrapped_xp = array_namespace(x) if axis is None: if x.ndim > 1: - raise ValueError("axis must be specified in cumulative_prod for more than one dimension") + raise ValueError( + "axis must be specified in cumulative_prod for more than one dimension" + ) axis = 0 res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs) @@ -334,24 +358,30 @@ def cumulative_prod( initial_shape = list(x.shape) initial_shape[axis] = 1 res = xp.concatenate( - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], + [ + wrapped_xp.ones( + shape=initial_shape, dtype=res.dtype, device=_get_device(res) + ), + res, + ], axis=axis, ) return res + # The min and max argument names in clip are different and not optional in numpy, and type # promotion behavior is different. def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, *, xp: Namespace, # TODO: np.clip has other ufunc kwargs - out: Optional[Array] = None, + out: Array | None = None, ) -> Array: - def _isscalar(a): + def _isscalar(a: object) -> TypeIs[int | float | None]: return isinstance(a, (int, float, type(None))) min_shape = () if _isscalar(min) else min.shape @@ -378,7 +408,6 @@ def _isscalar(a): # but an answer of 0 might be preferred. See # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue. - # At least handle the case of Python integers correctly (see # https://github.com/numpy/numpy/pull/26892). if wrapped_xp.isdtype(x.dtype, "integral"): @@ -390,6 +419,7 @@ def _isscalar(a): dev = _get_device(x) if out is None: out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev) + assert out is not None # workaround for a type-narrowing issue in pyright out[()] = x if min is not None: @@ -407,19 +437,21 @@ def _isscalar(a): # Return a scalar for 0-D return out[()] + # Unlike transpose(), the axes argument to permute_dims() is required. -def permute_dims(x: Array, /, axes: Tuple[int, ...], xp: Namespace) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...], xp: Namespace) -> Array: return xp.transpose(x, axes) + # np.reshape calls the keyword argument 'newshape' instead of 'shape' def reshape( x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], xp: Namespace, *, copy: Optional[bool] = None, - **kwargs, + **kwargs: object, ) -> Array: if copy is True: x = x.copy() @@ -429,6 +461,7 @@ def reshape( return y return xp.reshape(x, shape, **kwargs) + # The descending keyword is new in sort and argsort, and 'kind' replaced with # 'stable' def argsort( @@ -439,13 +472,13 @@ def argsort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" if not descending: res = xp.argsort(x, axis=axis, **kwargs) else: @@ -462,6 +495,7 @@ def argsort( res = max_i - res return res + def sort( x: Array, /, @@ -470,68 +504,78 @@ def sort( axis: int = -1, descending: bool = False, stable: bool = True, - **kwargs, + **kwargs: object, ) -> Array: # Note: this keyword argument is different, and the default is different. # We set it in kwargs like this because numpy.sort uses kind='quicksort' # as the default whereas cupy.sort uses kind=None. if stable: - kwargs['kind'] = "stable" + kwargs["kind"] = "stable" res = xp.sort(x, axis=axis, **kwargs) if descending: res = xp.flip(res, axis=axis) return res + # nonzero should error for zero-dimensional arrays -def nonzero(x: Array, /, xp: Namespace, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return xp.nonzero(x, **kwargs) + # ceil, floor, and trunc return integers for integer inputs -def ceil(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.ceil(x, **kwargs) -def floor(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.floor(x, **kwargs) -def trunc(x: Array, /, xp: Namespace, **kwargs) -> Array: + +def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: if xp.issubdtype(x.dtype, xp.integer): return x return xp.trunc(x, **kwargs) + # linear algebra functions -def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: + +def matmul(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.matmul(x1, x2, **kwargs) + # Unlike transpose, matrix_transpose only transposes the last two axes. def matrix_transpose(x: Array, /, xp: Namespace) -> Array: if x.ndim < 2: raise ValueError("x must be at least 2-dimensional for matrix_transpose") return xp.swapaxes(x, -1, -2) + def tensordot( x1: Array, x2: Array, /, xp: Namespace, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: return xp.tensordot(x1, x2, axes=axes, **kwargs) + def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: if x1.shape[axis] != x2.shape[axis]: raise ValueError("x1 and x2 must have the same size along the given axis") - if hasattr(xp, 'broadcast_tensors'): + if hasattr(xp, "broadcast_tensors"): _broadcast = xp.broadcast_tensors else: _broadcast = xp.broadcast_arrays @@ -543,14 +587,16 @@ def vecdot(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1) -> Array: res = xp.conj(x1_[..., None, :]) @ x2_[..., None] return res[..., 0, 0] + # isdtype is a new function in the 2022.12 array API specification. + def isdtype( dtype: DType, - kind: Union[DType, str, Tuple[Union[DType, str], ...]], + kind: DType | str | tuple[DType | str, ...], xp: Namespace, *, - _tuple: bool = True, # Disallow nested tuples + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -563,21 +609,24 @@ def isdtype( for more details """ if isinstance(kind, tuple) and _tuple: - return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + return any( + isdtype(dtype, k, xp, _tuple=False) + for k in cast("tuple[DType | str, ...]", kind) + ) elif isinstance(kind, str): - if kind == 'bool': + if kind == "bool": return dtype == xp.bool_ - elif kind == 'signed integer': + elif kind == "signed integer": return xp.issubdtype(dtype, xp.signedinteger) - elif kind == 'unsigned integer': + elif kind == "unsigned integer": return xp.issubdtype(dtype, xp.unsignedinteger) - elif kind == 'integral': + elif kind == "integral": return xp.issubdtype(dtype, xp.integer) - elif kind == 'real floating': + elif kind == "real floating": return xp.issubdtype(dtype, xp.floating) - elif kind == 'complex floating': + elif kind == "complex floating": return xp.issubdtype(dtype, xp.complexfloating) - elif kind == 'numeric': + elif kind == "numeric": return xp.issubdtype(dtype, xp.number) else: raise ValueError(f"Unrecognized data type kind: {kind!r}") @@ -588,24 +637,27 @@ def isdtype( # array_api_strict implementation will be very strict. return dtype == kind + # unstack is a new function in the 2023.12 array API standard -def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> Tuple[Array, ...]: +def unstack(x: Array, /, xp: Namespace, *, axis: int = 0) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("Input array must be at least 1-d.") return tuple(xp.moveaxis(x, axis, 0)) + # numpy 1.26 does not use the standard definition for sign on complex numbers -def sign(x: Array, /, xp: Namespace, **kwargs) -> Array: - if isdtype(x.dtype, 'complex floating', xp=xp): - out = (x/xp.abs(x, **kwargs))[...] + +def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: + if isdtype(x.dtype, "complex floating", xp=xp): + out = (x / xp.abs(x, **kwargs))[...] # sign(0) = 0 but the above formula would give nan - out[x == 0+0j] = 0+0j + out[x == 0j] = 0j else: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp): + if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -626,13 +678,50 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: return xp.iinfo(type_.dtype) -__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', - 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', - 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims', - 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc', - 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype', - 'unstack', 'sign', 'finfo', 'iinfo'] - -_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] +__all__ = [ + "arange", + "empty", + "empty_like", + "eye", + "full", + "full_like", + "linspace", + "ones", + "ones_like", + "zeros", + "zeros_like", + "UniqueAllResult", + "UniqueCountsResult", + "UniqueInverseResult", + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + "std", + "var", + "cumulative_sum", + "cumulative_prod", + "clip", + "permute_dims", + "reshape", + "argsort", + "sort", + "nonzero", + "ceil", + "floor", + "trunc", + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + "isdtype", + "unstack", + "sign", + "finfo", + "iinfo", +] +_all_ignore = ["inspect", "array_namespace", "NamedTuple"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_fft.py b/array_api_compat/common/_fft.py index bd2a4e1a..18839d37 100644 --- a/array_api_compat/common/_fft.py +++ b/array_api_compat/common/_fft.py @@ -1,9 +1,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Union, Optional, Literal +from typing import Literal, TypeAlias -from ._typing import Device, Array, DType, Namespace +from ._typing import Array, Device, DType, Namespace + +_Norm: TypeAlias = Literal["backward", "ortho", "forward"] # Note: NumPy fft functions improperly upcast float32 and complex64 to # complex128, which is why we require wrapping them all here. @@ -13,9 +15,9 @@ def fft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.fft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -27,9 +29,9 @@ def ifft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -41,9 +43,9 @@ def fftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.fftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -55,9 +57,9 @@ def ifftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.ifftn(x, s=s, axes=axes, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -69,9 +71,9 @@ def rfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.float32: @@ -83,9 +85,9 @@ def irfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfft(x, n=n, axis=axis, norm=norm) if x.dtype == xp.complex64: @@ -97,9 +99,9 @@ def rfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.rfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.float32: @@ -111,9 +113,9 @@ def irfftn( /, xp: Namespace, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, - norm: Literal["backward", "ortho", "forward"] = "backward", + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, + norm: _Norm = "backward", ) -> Array: res = xp.fft.irfftn(x, s=s, axes=axes, norm=norm) if x.dtype == xp.complex64: @@ -125,9 +127,9 @@ def hfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.hfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -139,9 +141,9 @@ def ihfft( /, xp: Namespace, *, - n: Optional[int] = None, + n: int | None = None, axis: int = -1, - norm: Literal["backward", "ortho", "forward"] = "backward", + norm: _Norm = "backward", ) -> Array: res = xp.fft.ihfft(x, n=n, axis=axis, norm=norm) if x.dtype in [xp.float32, xp.complex64]: @@ -154,8 +156,8 @@ def fftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -170,8 +172,8 @@ def rfftfreq( xp: Namespace, *, d: float = 1.0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, ) -> Array: if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") @@ -181,12 +183,12 @@ def rfftfreq( return res def fftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.fftshift(x, axes=axes) def ifftshift( - x: Array, /, xp: Namespace, *, axes: Union[int, Sequence[int]] = None + x: Array, /, xp: Namespace, *, axes: int | Sequence[int] | None = None ) -> Array: return xp.fft.ifftshift(x, axes=axes) @@ -206,3 +208,6 @@ def ifftshift( "fftshift", "ifftshift", ] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..db3e4cd7 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -5,33 +5,82 @@ that are in __all__ are intended as additional helper functions for use by end users of the compat library. """ + from __future__ import annotations -import sys -import math import inspect +import math +import sys import warnings -from typing import Optional, Union, Any +from collections.abc import Collection +from typing import ( + TYPE_CHECKING, + Any, + Final, + Literal, + SupportsIndex, + TypeAlias, + TypeGuard, + TypeVar, + cast, + overload, +) + +from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace + +if TYPE_CHECKING: + + import dask.array as da + import jax + import ndonnx as ndx + import numpy as np + import numpy.typing as npt + import sparse # pyright: ignore[reportMissingTypeStubs] + import torch + + # TODO: import from typing (requires Python >=3.13) + from typing_extensions import TypeIs, TypeVar -from ._typing import Array, Device, Namespace + _SizeT = TypeVar("_SizeT", bound = int | None) + _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] + _CupyArray: TypeAlias = Any # cupy has no py.typed -def _is_jax_zero_gradient_array(x: object) -> bool: + _ArrayApiObj: TypeAlias = ( + npt.NDArray[Any] + | da.Array + | jax.Array + | ndx.Array + | sparse.SparseArray + | torch.Tensor + | SupportsArrayNamespace[Any] + | _CupyArray + ) + +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) + + +def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if 'numpy' not in sys.modules or 'jax' not in sys.modules: + if "numpy" not in sys.modules or "jax" not in sys.modules: return False - import numpy as np import jax + import numpy as np - return isinstance(x, np.ndarray) and x.dtype == jax.float0 + jax_float0 = cast("np.dtype[np.void]", jax.float0) + return ( + isinstance(x, np.ndarray) + and cast("npt.NDArray[np.void]", x).dtype == jax_float0 + ) -def is_numpy_array(x: object) -> bool: +def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -53,14 +102,14 @@ def is_numpy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing NumPy if it isn't already - if 'numpy' not in sys.modules: + if "numpy" not in sys.modules: return False import numpy as np # TODO: Should we reject ndarray subclasses? return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) + and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip def is_cupy_array(x: object) -> bool: @@ -85,16 +134,16 @@ def is_cupy_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing CuPy if it isn't already - if 'cupy' not in sys.modules: + if "cupy" not in sys.modules: return False - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) + return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] -def is_torch_array(x: object) -> bool: +def is_torch_array(x: object) -> TypeIs[torch.Tensor]: """ Return True if `x` is a PyTorch tensor. @@ -113,7 +162,7 @@ def is_torch_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'torch' not in sys.modules: + if "torch" not in sys.modules: return False import torch @@ -122,7 +171,7 @@ def is_torch_array(x: object) -> bool: return isinstance(x, torch.Tensor) -def is_ndonnx_array(x: object) -> bool: +def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: """ Return True if `x` is a ndonnx Array. @@ -142,7 +191,7 @@ def is_ndonnx_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing torch if it isn't already - if 'ndonnx' not in sys.modules: + if "ndonnx" not in sys.modules: return False import ndonnx as ndx @@ -150,7 +199,7 @@ def is_ndonnx_array(x: object) -> bool: return isinstance(x, ndx.Array) -def is_dask_array(x: object) -> bool: +def is_dask_array(x: object) -> TypeIs[da.Array]: """ Return True if `x` is a dask.array Array. @@ -170,7 +219,7 @@ def is_dask_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing dask if it isn't already - if 'dask.array' not in sys.modules: + if "dask.array" not in sys.modules: return False import dask.array @@ -178,7 +227,7 @@ def is_dask_array(x: object) -> bool: return isinstance(x, dask.array.Array) -def is_jax_array(x: object) -> bool: +def is_jax_array(x: object) -> TypeIs[jax.Array]: """ Return True if `x` is a JAX array. @@ -199,7 +248,7 @@ def is_jax_array(x: object) -> bool: is_pydata_sparse_array """ # Avoid importing jax if it isn't already - if 'jax' not in sys.modules: + if "jax" not in sys.modules: return False import jax @@ -207,7 +256,7 @@ def is_jax_array(x: object) -> bool: return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) -def is_pydata_sparse_array(x) -> bool: +def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: """ Return True if `x` is an array from the `sparse` package. @@ -228,16 +277,16 @@ def is_pydata_sparse_array(x) -> bool: is_jax_array """ # Avoid importing jax if it isn't already - if 'sparse' not in sys.modules: + if "sparse" not in sys.modules: return False - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x: object) -> bool: +def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] """ Return True if `x` is an array API compatible array object. @@ -252,18 +301,20 @@ def is_array_api_obj(x: object) -> bool: is_dask_array is_jax_array """ - return is_numpy_array(x) \ - or is_cupy_array(x) \ - or is_torch_array(x) \ - or is_dask_array(x) \ - or is_jax_array(x) \ - or is_pydata_sparse_array(x) \ - or hasattr(x, '__array_namespace__') + return ( + is_numpy_array(x) + or is_cupy_array(x) + or is_torch_array(x) + or is_dask_array(x) + or is_jax_array(x) + or is_pydata_sparse_array(x) + or hasattr(x, "__array_namespace__") + ) def _compat_module_name() -> str: - assert __name__.endswith('.common._helpers') - return __name__.removesuffix('.common._helpers') + assert __name__.endswith(".common._helpers") + return __name__.removesuffix(".common._helpers") def is_numpy_namespace(xp: Namespace) -> bool: @@ -284,7 +335,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'} + return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} def is_cupy_namespace(xp: Namespace) -> bool: @@ -305,7 +356,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'} + return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} def is_torch_namespace(xp: Namespace) -> bool: @@ -326,7 +377,7 @@ def is_torch_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'torch', _compat_module_name() + '.torch'} + return xp.__name__ in {"torch", _compat_module_name() + ".torch"} def is_ndonnx_namespace(xp: Namespace) -> bool: @@ -345,7 +396,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'ndonnx' + return xp.__name__ == "ndonnx" def is_dask_namespace(xp: Namespace) -> bool: @@ -366,7 +417,7 @@ def is_dask_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'} + return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"} def is_jax_namespace(xp: Namespace) -> bool: @@ -388,7 +439,7 @@ def is_jax_namespace(xp: Namespace) -> bool: is_pydata_sparse_namespace is_array_api_strict_namespace """ - return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'} + return xp.__name__ in {"jax.numpy", "jax.experimental.array_api"} def is_pydata_sparse_namespace(xp: Namespace) -> bool: @@ -407,7 +458,7 @@ def is_pydata_sparse_namespace(xp: Namespace) -> bool: is_jax_namespace is_array_api_strict_namespace """ - return xp.__name__ == 'sparse' + return xp.__name__ == "sparse" def is_array_api_strict_namespace(xp: Namespace) -> bool: @@ -426,21 +477,24 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: is_jax_namespace is_pydata_sparse_namespace """ - return xp.__name__ == 'array_api_strict' + return xp.__name__ == "array_api_strict" -def _check_api_version(api_version: str) -> None: - if api_version in ['2021.12', '2022.12', '2023.12']: - warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12") - elif api_version is not None and api_version not in ['2021.12', '2022.12', - '2023.12', '2024.12']: - raise ValueError("Only the 2024.12 version of the array API specification is currently supported") +def _check_api_version(api_version: str | None) -> None: + if api_version in _API_VERSIONS_OLD: + warnings.warn( + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + ) + elif api_version is not None and api_version not in _API_VERSIONS: + raise ValueError( + "Only the 2024.12 version of the array API specification is currently supported" + ) def array_namespace( - *xs: Union[Array, bool, int, float, complex, None], - api_version: Optional[str] = None, - use_compat: Optional[bool] = None, + *xs: Array | complex | None, + api_version: str | None = None, + use_compat: bool | None = None, ) -> Namespace: """ Get the array API compatible namespace for the arrays `xs`. @@ -510,11 +564,13 @@ def your_function(x, y): _use_compat = use_compat in [None, True] - namespaces = set() + namespaces: set[Namespace] = set() for x in xs: if is_numpy_array(x): - from .. import numpy as numpy_namespace import numpy as np + + from .. import numpy as numpy_namespace + if use_compat is True: _check_api_version(api_version) namespaces.add(numpy_namespace) @@ -528,25 +584,31 @@ def your_function(x, y): if _use_compat: _check_api_version(api_version) from .. import cupy as cupy_namespace + namespaces.add(cupy_namespace) else: - import cupy as cp + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + namespaces.add(cp) elif is_torch_array(x): if _use_compat: _check_api_version(api_version) from .. import torch as torch_namespace + namespaces.add(torch_namespace) else: import torch + namespaces.add(torch) elif is_dask_array(x): if _use_compat: _check_api_version(api_version) from ..dask import array as dask_namespace + namespaces.add(dask_namespace) else: import dask.array as da + namespaces.add(da) elif is_jax_array(x): if use_compat is True: @@ -558,23 +620,27 @@ def your_function(x, y): # JAX v0.4.32 and newer implements the array API directly in jax.numpy. # For older JAX versions, it is available via jax.experimental.array_api. import jax.numpy + if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: - import jax.experimental.array_api as jnp + import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: - import sparse + import sparse # pyright: ignore[reportMissingTypeStubs] # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) - elif hasattr(x, '__array_namespace__'): + elif hasattr(x, "__array_namespace__"): if use_compat is True: - raise ValueError("The given array does not have an array-api-compat wrapper") + raise ValueError( + "The given array does not have an array-api-compat wrapper" + ) + x = cast("SupportsArrayNamespace[Any]", x) namespaces.add(x.__array_namespace__(api_version=api_version)) elif isinstance(x, (bool, int, float, complex, type(None))): continue @@ -588,15 +654,16 @@ def your_function(x, y): if len(namespaces) != 1: raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - xp, = namespaces + (xp,) = namespaces return xp + # backwards compatibility alias get_namespace = array_namespace -def _check_device(bare_xp, device): +def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction] """ Validate dummy device on device-less array backends. @@ -609,11 +676,11 @@ def _check_device(bare_xp, device): https://github.com/data-apis/array-api-compat/pull/293 """ - if bare_xp is sys.modules.get('numpy'): + if bare_xp is sys.modules.get("numpy"): if device not in ("cpu", None): raise ValueError(f"Unsupported device for NumPy: {device!r}") - elif bare_xp is sys.modules.get('dask.array'): + elif bare_xp is sys.modules.get("dask.array"): if device not in ("cpu", _DASK_DEVICE, None): raise ValueError(f"Unsupported device for Dask: {device!r}") @@ -622,18 +689,20 @@ def _check_device(bare_xp, device): # when the array backend is not the CPU. # (since it is not easy to tell which device a dask array is on) class _dask_device: - def __repr__(self): + def __repr__(self) -> Literal["DASK_DEVICE"]: return "DASK_DEVICE" + _DASK_DEVICE = _dask_device() + # device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray # or cupy.ndarray. They are not included in array objects of this library # because this library just reuses the respective ndarray classes without # wrapping or subclassing them. These helper functions can be used instead of # the wrapper functions for libraries that need to support both NumPy/CuPy and # other libraries that use devices. -def device(x: Array, /) -> Device: +def device(x: _ArrayApiObj, /) -> Device: """ Hardware device the array data resides on. @@ -669,7 +738,7 @@ def device(x: Array, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): + if is_numpy_array(x._meta): # pyright: ignore # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -679,7 +748,7 @@ def device(x: Array, /) -> Device: # Return None in this case. Note that this workaround breaks # the standard and will result in new arrays being created on the # default device instead of the same device as the input array(s). - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) # Older JAX releases had .device() as a method, which has been replaced # with a property in accordance with the standard. if inspect.ismethod(x_device): @@ -688,27 +757,34 @@ def device(x: Array, /) -> Device: return x_device elif is_pydata_sparse_array(x): # `sparse` will gain `.device`, so check for this first. - x_device = getattr(x, 'device', None) + x_device = getattr(x, "device", None) if x_device is not None: return x_device # Everything but DOK has this attr. try: - inner = x.data + inner = x.data # pyright: ignore except AttributeError: return "cpu" # Return the device of the constituent array - return device(inner) - return x.device + return device(inner) # pyright: ignore + return x.device # pyright: ignore + # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device -def _cupy_to_device(x, device, /, stream=None): - import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime +def _cupy_to_device( + x: _CupyArray, + device: Device, + /, + stream: int | Any | None = None, +) -> _CupyArray: + import cupy as cp # pyright: ignore[reportMissingTypeStubs] + from cupy.cuda import Device as _Device # pyright: ignore + from cupy.cuda import stream as stream_module # pyright: ignore + from cupy_backends.cuda.api import runtime # pyright: ignore if device == x.device: return x @@ -721,33 +797,40 @@ def _cupy_to_device(x, device, /, stream=None): raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None + prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] + prev_stream = None if stream is not None: - prev_stream = stream_module.get_current_stream() + prev_stream: Any = stream_module.get_current_stream() # pyright: ignore # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): + stream = cp.cuda.ExternalStream(stream) # pyright: ignore + elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] pass else: - raise ValueError('the input stream is not recognized') - stream.use() + raise ValueError("the input stream is not recognized") + stream.use() # pyright: ignore[reportUnknownMemberType] try: - runtime.setDevice(device.id) + runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] arr = x.copy() finally: - runtime.setDevice(prev_device) + runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] if stream is not None: prev_stream.use() return arr -def _torch_to_device(x, device, /, stream=None): + +def _torch_to_device( + x: torch.Tensor, + device: torch.device | str | int, + /, + stream: None = None, +) -> torch.Tensor: if stream is not None: raise NotImplementedError return x.to(device) -def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: + +def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. @@ -767,7 +850,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] a ``device`` object (see the `Device Support `__ section of the array API specification). - stream: Optional[Union[int, Any]] + stream: int | Any | None stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using @@ -799,25 +882,26 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] if is_numpy_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_cupy_array(x): # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) + return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") # TODO: What if our array is on the GPU already? - if device == 'cpu': + if device == "cpu": return x raise ValueError(f"Unsupported device {device!r}") elif is_jax_array(x): if not hasattr(x, "__array_namespace__"): # In JAX v0.4.31 and older, this import adds to_device method to x... - import jax.experimental.array_api # noqa: F401 + import jax.experimental.array_api # noqa: F401 # pyright: ignore + # ... but only on eager JAX. It won't work inside jax.jit. if not hasattr(x, "to_device"): return x @@ -826,10 +910,16 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] # Perform trivial check to return the same array if # device is same instead of err-ing. return x - return x.to_device(device, stream=stream) + return x.to_device(device, stream=stream) # pyright: ignore -def size(x: Array) -> int | None: +@overload +def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +@overload +def size(x: HasShape[Collection[None]]) -> None: ... +@overload +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... +def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ Return the total number of elements of x. @@ -844,7 +934,7 @@ def size(x: Array) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(x.shape) + out = math.prod(cast("Collection[SupportsIndex]", x.shape)) # dask.array.Array.shape can contain NaN return None if math.isnan(out) else out @@ -907,7 +997,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array - s = size(x) + s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) if s is None: return True xp = array_namespace(x) @@ -952,4 +1042,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ['sys', 'math', 'inspect', 'warnings'] +_all_ignore = ["sys", "math", "inspect", "warnings"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index d1e7ebd8..7e002aed 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -1,23 +1,33 @@ from __future__ import annotations import math -from typing import Literal, NamedTuple, Optional, Tuple, Union +from typing import Literal, NamedTuple, cast import numpy as np + if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: from numpy.core.numeric import normalize_axis_tuple -from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype from .._internal import get_xp -from ._typing import Array, Namespace +from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot +from ._typing import Array, DType, Namespace + # These are in the main NumPy namespace but not in numpy.linalg -def cross(x1: Array, x2: Array, /, xp: Namespace, *, axis: int = -1, **kwargs) -> Array: +def cross( + x1: Array, + x2: Array, + /, + xp: Namespace, + *, + axis: int = -1, + **kwargs: object, +) -> Array: return xp.cross(x1, x2, axis=axis, **kwargs) -def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs) -> Array: +def outer(x1: Array, x2: Array, /, xp: Namespace, **kwargs: object) -> Array: return xp.outer(x1, x2, **kwargs) class EighResult(NamedTuple): @@ -39,46 +49,66 @@ class SVDResult(NamedTuple): # These functions are the same as their NumPy counterparts except they return # a namedtuple. -def eigh(x: Array, /, xp: Namespace, **kwargs) -> EighResult: +def eigh(x: Array, /, xp: Namespace, **kwargs: object) -> EighResult: return EighResult(*xp.linalg.eigh(x, **kwargs)) -def qr(x: Array, /, xp: Namespace, *, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: Array, + /, + xp: Namespace, + *, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: return QRResult(*xp.linalg.qr(x, mode=mode, **kwargs)) -def slogdet(x: Array, /, xp: Namespace, **kwargs) -> SlogdetResult: +def slogdet(x: Array, /, xp: Namespace, **kwargs: object) -> SlogdetResult: return SlogdetResult(*xp.linalg.slogdet(x, **kwargs)) def svd( - x: Array, /, xp: Namespace, *, full_matrices: bool = True, **kwargs + x: Array, + /, + xp: Namespace, + *, + full_matrices: bool = True, + **kwargs: object, ) -> SVDResult: return SVDResult(*xp.linalg.svd(x, full_matrices=full_matrices, **kwargs)) # These functions have additional keyword arguments # The upper keyword argument is new from NumPy -def cholesky(x: Array, /, xp: Namespace, *, upper: bool = False, **kwargs) -> Array: +def cholesky( + x: Array, + /, + xp: Namespace, + *, + upper: bool = False, + **kwargs: object, +) -> Array: L = xp.linalg.cholesky(x, **kwargs) if upper: U = get_xp(xp)(matrix_transpose)(L) if get_xp(xp)(isdtype)(U.dtype, 'complex floating'): - U = xp.conj(U) + U = xp.conj(U) # pyright: ignore[reportConstantRedefinition] return U return L # The rtol keyword argument of matrix_rank() and pinv() is new from NumPy. # Note that it has a different semantic meaning from tol and rcond. -def matrix_rank(x: Array, - /, - xp: Namespace, - *, - rtol: Optional[Union[float, Array]] = None, - **kwargs) -> Array: +def matrix_rank( + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, +) -> Array: # this is different from xp.linalg.matrix_rank, which supports 1 # dimensional arrays. if x.ndim < 2: raise xp.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional") - S = get_xp(xp)(svdvals)(x, **kwargs) + S: Array = get_xp(xp)(svdvals)(x, **kwargs) if rtol is None: tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * xp.finfo(S.dtype).eps else: @@ -88,7 +118,12 @@ def matrix_rank(x: Array, return xp.count_nonzero(S > tol, axis=-1) def pinv( - x: Array, /, xp: Namespace, *, rtol: Optional[Union[float, Array]] = None, **kwargs + x: Array, + /, + xp: Namespace, + *, + rtol: float | Array | None = None, + **kwargs: object, ) -> Array: # this is different from xp.linalg.pinv, which does not multiply the # default tolerance by max(M, N). @@ -104,13 +139,13 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro', + ord: float | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) # svdvals is not in NumPy (but it is in SciPy). It is equivalent to # xp.linalg.svd(compute_uv=False). -def svdvals(x: Array, /, xp: Namespace) -> Union[Array, Tuple[Array, ...]]: +def svdvals(x: Array, /, xp: Namespace) -> Array | tuple[Array, ...]: return xp.linalg.svd(x, compute_uv=False) def vector_norm( @@ -118,9 +153,9 @@ def vector_norm( /, xp: Namespace, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: Optional[Union[int, float]] = 2, + ord: float = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make @@ -133,7 +168,10 @@ def vector_norm( elif isinstance(axis, tuple): # Note: The axis argument supports any number of axes, whereas # xp.linalg.norm() only supports a single axis for vector norm. - normalized_axis = normalize_axis_tuple(axis, x.ndim) + normalized_axis = cast( + "tuple[int, ...]", + normalize_axis_tuple(axis, x.ndim), # pyright: ignore[reportCallIssue] + ) rest = tuple(i for i in range(x.ndim) if i not in normalized_axis) newshape = axis + rest _x = xp.transpose(x, newshape).reshape( @@ -149,7 +187,13 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + _axis = cast( + "tuple[int, ...]", + normalize_axis_tuple( # pyright: ignore[reportCallIssue] + range(x.ndim) if axis is None else axis, + x.ndim, + ), + ) for i in _axis: shape[i] = 1 res = xp.reshape(res, tuple(shape)) @@ -159,11 +203,17 @@ def vector_norm( # xp.diagonal and xp.trace operate on the first two axes whereas these # operates on the last two -def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs) -> Array: +def diagonal(x: Array, /, xp: Namespace, *, offset: int = 0, **kwargs: object) -> Array: return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs) def trace( - x: Array, /, xp: Namespace, *, offset: int = 0, dtype=None, **kwargs + x: Array, + /, + xp: Namespace, + *, + offset: int = 0, + dtype: DType | None = None, + **kwargs: object, ) -> Array: return xp.asarray( xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs) @@ -176,3 +226,7 @@ def trace( 'trace'] _all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index 4c3b356b..d7deade1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,24 +1,150 @@ from __future__ import annotations + +from collections.abc import Mapping from types import ModuleType as Namespace -from typing import Any, TypeVar, Protocol +from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar + +if TYPE_CHECKING: + from _typeshed import Incomplete + + SupportsBufferProtocol: TypeAlias = Incomplete + Array: TypeAlias = Incomplete + Device: TypeAlias = Incomplete + DType: TypeAlias = Incomplete +else: + SupportsBufferProtocol = object + Array = object + Device = object + DType = object + + +_T_co = TypeVar("_T_co", covariant=True) + + +class NestedSequence(Protocol[_T_co]): + def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... + def __len__(self, /) -> int: ... + + +class SupportsArrayNamespace(Protocol[_T_co]): + def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ... + + +class HasShape(Protocol[_T_co]): + @property + def shape(self, /) -> _T_co: ... + + +# Return type of `__array_namespace_info__.default_dtypes` +Capabilities = TypedDict( + "Capabilities", + { + "boolean indexing": bool, + "data-dependent shapes": bool, + "max dimensions": int, + }, +) + +# Return type of `__array_namespace_info__.default_dtypes` +DefaultDTypes = TypedDict( + "DefaultDTypes", + { + "real floating": DType, + "complex floating": DType, + "integral": DType, + "indexing": DType, + }, +) + + +_DTypeKind: TypeAlias = Literal[ + "bool", + "signed integer", + "unsigned integer", + "integral", + "real floating", + "complex floating", + "numeric", +] +# Type of the `kind` parameter in `__array_namespace_info__.dtypes` +DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...] + + +# `__array_namespace_info__.dtypes(kind="bool")` +class DTypesBool(TypedDict): + bool: DType + + +# `__array_namespace_info__.dtypes(kind="signed integer")` +class DTypesSigned(TypedDict): + int8: DType + int16: DType + int32: DType + int64: DType + + +# `__array_namespace_info__.dtypes(kind="unsigned integer")` +class DTypesUnsigned(TypedDict): + uint8: DType + uint16: DType + uint32: DType + uint64: DType + + +# `__array_namespace_info__.dtypes(kind="integral")` +class DTypesIntegral(DTypesSigned, DTypesUnsigned): + pass + + +# `__array_namespace_info__.dtypes(kind="real floating")` +class DTypesReal(TypedDict): + float32: DType + float64: DType + + +# `__array_namespace_info__.dtypes(kind="complex floating")` +class DTypesComplex(TypedDict): + complex64: DType + complex128: DType + + +# `__array_namespace_info__.dtypes(kind="numeric")` +class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): + pass + + +# `__array_namespace_info__.dtypes(kind=None)` (default) +class DTypesAll(DTypesBool, DTypesNumeric): + pass + + +# `__array_namespace_info__.dtypes(kind=?)` (fallback) +DTypesAny: TypeAlias = Mapping[str, DType] + __all__ = [ "Array", + "Capabilities", "DType", + "DTypeKind", + "DTypesAny", + "DTypesAll", + "DTypesBool", + "DTypesNumeric", + "DTypesIntegral", + "DTypesSigned", + "DTypesUnsigned", + "DTypesReal", + "DTypesComplex", + "DefaultDTypes", "Device", + "HasShape", "Namespace", "NestedSequence", + "SupportsArrayNamespace", "SupportsBufferProtocol", ] -_T_co = TypeVar("_T_co", covariant=True) - -class NestedSequence(Protocol[_T_co]): - def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... - def __len__(self, /) -> int: ... - -SupportsBufferProtocol = Any -Array = Any -Device = Any -DType = Any +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index bb649306..1e47b960 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,9 +1,11 @@ -from dask.array import * # noqa: F403 +from typing import Final + +from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index e7ddde78..9687a9cd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -1,28 +1,38 @@ +# pyright: reportPrivateUsage=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + from __future__ import annotations -from typing import Callable, Optional, Union +from builtins import bool as py_bool +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from typing_extensions import TypeIs +import dask.array as da import numpy as np +from numpy import bool_ as bool from numpy import ( - # dtypes - bool_ as bool, + can_cast, + complex64, + complex128, float32, float64, int8, int16, int32, int64, + result_type, uint8, uint16, uint32, uint64, - complex64, - complex128, - can_cast, - result_type, ) -import dask.array as da +from ..._internal import get_xp from ...common import _aliases, _helpers, array_namespace from ...common._typing import ( Array, @@ -31,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ..._internal import get_xp from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) @@ -44,8 +53,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: """ Array API compatibility wrapper for astype(). @@ -69,14 +78,14 @@ def astype( # not pass stop/step as keyword arguments, which will cause # an error with dask def arange( - start: Union[int, float], + start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for arange(). @@ -87,7 +96,7 @@ def arange( # TODO: respect device keyword? _helpers._check_device(da, device) - args = [start] + args: list[Any] = [start] if stop is not None: args.append(stop) else: @@ -137,18 +146,13 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -164,7 +168,7 @@ def asarray( if copy is False: raise ValueError("Unable to avoid copy when changing dtype") obj = obj.astype(dtype) - return obj.copy() if copy else obj + return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: raise NotImplementedError( @@ -177,22 +181,21 @@ def asarray( return da.from_array(obj) -from dask.array import ( - # Element wise aliases - arccos as acos, - arccosh as acosh, - arcsin as asin, - arcsinh as asinh, - arctan as atan, - arctan2 as atan2, - arctanh as atanh, - left_shift as bitwise_left_shift, - right_shift as bitwise_right_shift, - invert as bitwise_invert, - power as pow, - # Other - concatenate as concat, -) +# Element wise aliases +from dask.array import arccos as acos +from dask.array import arccosh as acosh +from dask.array import arcsin as asin +from dask.array import arcsinh as asinh +from dask.array import arctan as atan +from dask.array import arctan2 as atan2 +from dask.array import arctanh as atanh + +# Other +from dask.array import concatenate as concat +from dask.array import invert as bitwise_invert +from dask.array import left_shift as bitwise_left_shift +from dask.array import power as pow +from dask.array import right_shift as bitwise_right_shift # dask.array.clip does not work unless all three arguments are provided. @@ -202,8 +205,8 @@ def asarray( def clip( x: Array, /, - min: Optional[Union[int, float, Array]] = None, - max: Optional[Union[int, float, Array]] = None, + min: float | Array | None = None, + max: float | Array | None = None, ) -> Array: """ Array API compatibility wrapper for clip(). @@ -212,8 +215,8 @@ def clip( specification for more details. """ - def _isscalar(a): - return isinstance(a, (int, float, type(None))) + def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]: + return a is None or isinstance(a, (int, float)) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array], def sort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of sort() in Dask. @@ -296,7 +304,12 @@ def sort( def argsort( - x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True + x: Array, + /, + *, + axis: int = -1, + descending: py_bool = False, + stable: py_bool = True, ) -> Array: """ Array API compatibility layer around the lack of argsort() in Dask. @@ -330,25 +343,34 @@ def argsort( # dask.array.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | None = None, + keepdims: py_bool = False, ) -> Array: - result = da.count_nonzero(x, axis) - if keepdims: - if axis is None: - return da.reshape(result, [1]*x.ndim) - return da.expand_dims(result, axis) - return result - - + result = da.count_nonzero(x, axis) + if keepdims: + if axis is None: + return da.reshape(result, [1] * x.ndim) + return da.expand_dims(result, axis) + return result + + +__all__ = [ + "__array_namespace_info__", + "count_nonzero", + "bool", + "int8", "int16", "int32", "int64", + "uint8", "uint16", "uint32", "uint64", + "float32", "float64", + "complex64", "complex128", + "asarray", "astype", "can_cast", "result_type", + "pow", + "concat", + "acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh", + "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", +] # fmt: skip +__all__ += _aliases.__all__ +_all_ignore = ["array_namespace", "get_xp", "da", "np"] -__all__ = _aliases.__all__ + [ - '__array_namespace_info__', 'asarray', 'astype', 'acos', - 'acosh', 'asin', 'asinh', 'atan', 'atan2', - 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow', 'can_cast', - 'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128', - 'can_cast', 'count_nonzero', 'result_type'] -_all_ignore = ["array_namespace", "get_xp", "da", "np"] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 614f43d9..9e4d736f 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -7,25 +7,51 @@ more details. """ + +# pyright: reportPrivateUsage=false + +from __future__ import annotations + +from typing import Literal as L +from typing import TypeAlias, overload + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) -from ...common._helpers import _DASK_DEVICE +from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._typing import ( + Capabilities, + DefaultDTypes, + DType, + DTypeKind, + DTypesAll, + DTypesAny, + DTypesBool, + DTypesComplex, + DTypesIntegral, + DTypesNumeric, + DTypesReal, + DTypesSigned, + DTypesUnsigned, +) + +_Device: TypeAlias = L["cpu"] | _dask_device + class __array_namespace_info__: """ @@ -59,9 +85,9 @@ class __array_namespace_info__: """ - __module__ = 'dask.array' + __module__ = "dask.array" - def capabilities(self): + def capabilities(self) -> Capabilities: """ Return a dictionary of array API library capabilities. @@ -116,7 +142,7 @@ def capabilities(self): "max dimensions": 64, } - def default_device(self): + def default_device(self) -> L["cpu"]: """ The default device used for new Dask arrays. @@ -143,7 +169,7 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -184,8 +210,8 @@ def default_dtypes(self, *, device=None): """ if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' + f"but received: {device!r}" ) return { "real floating": dtype(float64), @@ -194,7 +220,41 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: None = None + ) -> DTypesAll: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["bool"] + ) -> DTypesBool: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["signed integer"] + ) -> DTypesSigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + ) -> DTypesUnsigned: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["integral"] + ) -> DTypesIntegral: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["real floating"] + ) -> DTypesReal: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["complex floating"] + ) -> DTypesComplex: ... + @overload + def dtypes( + self, /, *, device: _Device | None = None, kind: L["numeric"] + ) -> DTypesNumeric: ... + def dtypes( + self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + ) -> DTypesAny: """ The array API data types supported by Dask. @@ -251,7 +311,7 @@ def dtypes(self, *, device=None, kind=None): if device not in ["cpu", _DASK_DEVICE, None]: raise ValueError( 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f' {device}' + f" {device}" ) if kind is None: return { @@ -321,14 +381,14 @@ def dtypes(self, *, device=None, kind=None): "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): - res = {} + if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[_Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index bd53f0df..0825386e 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -3,15 +3,16 @@ from typing import Literal import dask.array as da + +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, outer, tensordot + # Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer -# These functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot +from dask.array.linalg import * # noqa: F403 from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array +from ...common._typing import Array as _Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with @@ -32,8 +33,11 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', - **kwargs) -> QRResult: +def qr( + x: _Array, + mode: Literal["reduced", "complete"] = "reduced", + **kwargs: object, +) -> QRResult: if mode != "reduced": raise ValueError("dask arrays only support using mode='reduced'") return QRResult(*da.linalg.qr(x, **kwargs)) @@ -46,12 +50,12 @@ def qr(x: Array, mode: Literal['reduced', 'complete'] = 'reduced', # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: Array) -> Array: +def svdvals(x: _Array) -> _Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 6a5d9867..f7b558ba 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,10 +1,16 @@ -from numpy import * # noqa: F403 +# ruff: noqa: PLC0414 +from typing import Final + +from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] # from numpy import * doesn't overwrite these builtin names -from numpy import abs, max, min, round # noqa: F401 +from numpy import abs as abs +from numpy import max as max +from numpy import min as min +from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -13,9 +19,17 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__package__ + ".linalg") + +__import__(__package__ + ".fft") + +from ..common._helpers import * # noqa: F403 +from .linalg import matrix_transpose, vecdot # noqa: F401 -from .linalg import matrix_transpose, vecdot # noqa: F401 +try: + # Used in asarray(). Not present in older versions. + from numpy import _CopyMode # noqa: F401 +except ImportError: + pass -__array_api_version__ = '2024.12' +__array_api_version__: Final = "2024.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d1fd46a1..d8792611 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -1,6 +1,10 @@ +# pyright: reportPrivateUsage=false from __future__ import annotations -from typing import Optional, Union +from builtins import bool as py_bool +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast + +import numpy as np from .._internal import get_xp from ..common import _aliases, _helpers @@ -8,7 +12,12 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -import numpy as np +if TYPE_CHECKING: + from typing_extensions import Buffer, TypeIs + +# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: +# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 +_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode bool = np.bool_ @@ -65,9 +74,9 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj): +def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] try: - memoryview(obj) + memoryview(obj) # pyright: ignore[reportArgumentType] except TypeError: return False return True @@ -78,18 +87,13 @@ def _supports_buffer_protocol(obj): # complicated enough that it's easier to define it separately for each module # rather than trying to combine everything into one function in common/ def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[Union[bool, np._CopyMode]] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: _Copy | None = None, + **kwargs: Any, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -106,7 +110,7 @@ def asarray( elif copy is True: copy = np._CopyMode.ALWAYS - return np.array(obj, copy=copy, dtype=dtype, **kwargs) + return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore def astype( @@ -114,8 +118,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: _helpers._check_device(np, device) return x.astype(dtype=dtype, copy=copy) @@ -123,8 +127,14 @@ def astype( # count_nonzero returns a python int for axis=None and keepdims=False # https://github.com/numpy/numpy/issues/17562 -def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: - result = np.count_nonzero(x, axis=axis, keepdims=keepdims) +def count_nonzero( + x: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, +) -> Array: + # NOTE: this is currently incorrectly typed in numpy, but will be fixed in + # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 + result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result @@ -132,25 +142,43 @@ def count_nonzero(x: Array, axis=None, keepdims=False) -> Array: # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np, 'vecdot'): +if hasattr(np, "vecdot"): vecdot = np.vecdot else: vecdot = get_xp(np)(_aliases.vecdot) -if hasattr(np, 'isdtype'): +if hasattr(np, "isdtype"): isdtype = np.isdtype else: isdtype = get_xp(np)(_aliases.isdtype) -if hasattr(np, 'unstack'): +if hasattr(np, "unstack"): unstack = np.unstack else: unstack = get_xp(np)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', - 'acos', 'acosh', 'asin', 'asinh', 'atan', - 'atan2', 'atanh', 'bitwise_left_shift', - 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow'] - -_all_ignore = ['np', 'get_xp'] +__all__ = [ + "__array_namespace_info__", + "asarray", + "astype", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_right_shift", + "bool", + "concat", + "count_nonzero", + "pow", +] +__all__ += _aliases.__all__ +_all_ignore = ["np", "get_xp"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index 365855b8..f307f62c 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,24 +7,28 @@ more details. """ +from __future__ import annotations + +from numpy import bool_ as bool from numpy import ( + complex64, + complex128, dtype, - bool_ as bool, - intp, + float32, + float64, int8, int16, int32, int64, + intp, uint8, uint16, uint32, uint64, - float32, - float64, - complex64, - complex128, ) +from ._typing import Device, DType + class __array_namespace_info__: """ @@ -131,7 +135,11 @@ def default_device(self): """ return "cpu" - def default_dtypes(self, *, device=None): + def default_dtypes( + self, + *, + device: Device | None = None, + ) -> dict[str, dtype[intp | float64 | complex128]]: """ The default data types used for new NumPy arrays. @@ -183,7 +191,12 @@ def default_dtypes(self, *, device=None): "indexing": dtype(intp), } - def dtypes(self, *, device=None, kind=None): + def dtypes( + self, + *, + device: Device | None = None, + kind: str | tuple[str, ...] | None = None, + ) -> dict[str, DType]: """ The array API data types supported by NumPy. @@ -260,7 +273,7 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if kind == "bool": - return {"bool": bool} + return {"bool": dtype(bool)} if kind == "signed integer": return { "int8": dtype(int8), @@ -312,13 +325,13 @@ def dtypes(self, *, device=None, kind=None): "complex128": dtype(complex128), } if isinstance(kind, tuple): - res = {} + res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self): + def devices(self) -> list[Device]: """ The devices supported by NumPy. @@ -344,3 +357,10 @@ def devices(self): """ return ["cpu"] + + +__all__ = ["__array_namespace_info__"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index a6c96924..e771c788 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -1,31 +1,30 @@ from __future__ import annotations -__all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] - -from typing import Literal, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np -from numpy import ndarray as Array -Device = Literal["cpu"] +Device: TypeAlias = Literal["cpu"] + if TYPE_CHECKING: + # NumPy 1.x on Python 3.10 fails to parse np.dtype[] - DType = np.dtype[ - np.intp - | np.int8 - | np.int16 - | np.int32 - | np.int64 - | np.uint8 - | np.uint16 - | np.uint32 - | np.uint64 + DType: TypeAlias = np.dtype[ + np.bool_ + | np.integer[Any] | np.float32 | np.float64 | np.complex64 | np.complex128 - | np.bool ] + Array: TypeAlias = np.ndarray[Any, DType] else: - DType = np.dtype + DType: TypeAlias = np.dtype + Array: TypeAlias = np.ndarray + +__all__ = ["Array", "DType", "Device"] +_all_ignore = ["np"] + + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 28667594..06875f00 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,10 +1,9 @@ -from numpy.fft import * # noqa: F403 +import numpy as np from numpy.fft import __all__ as fft_all +from numpy.fft import fft2, ifft2, irfft2, rfft2 -from ..common import _fft from .._internal import get_xp - -import numpy as np +from ..common import _fft fft = get_xp(np)(_fft.fft) ifft = get_xp(np)(_fft.ifft) @@ -21,7 +20,14 @@ fftshift = get_xp(np)(_fft.fftshift) ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = fft_all + _fft.__all__ + +__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] +__all__ += _fft.__all__ + + +def __dir__() -> list[str]: + return __all__ + del get_xp del np diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 8f01593b..2d3e731d 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -1,14 +1,35 @@ -from numpy.linalg import * # noqa: F403 -from numpy.linalg import __all__ as linalg_all -import numpy as _np +# pyright: reportAttributeAccessIssue=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownVariableType=false + +from __future__ import annotations + +import numpy as np + +# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` +from numpy.linalg import ( + LinAlgError, + cond, + det, + eig, + eigvals, + eigvalsh, + inv, + lstsq, + matrix_power, + multi_dot, + norm, + tensorinv, + tensorsolve, +) -from ..common import _linalg from .._internal import get_xp +from ..common import _linalg # These functions are in both the main and linalg namespaces -from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 - -import numpy as np +from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 +from ._typing import Array cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) @@ -38,19 +59,28 @@ # To workaround this, the below is the code from np.linalg.solve except # only calling solve1 in the exactly 1D case. + # This code is here instead of in common because it is numpy specific. Also # note that CuPy's solve() does not currently support broadcasting (see # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). -def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: +def solve(x1: Array, x2: Array, /) -> Array: try: from numpy.linalg._linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) except ImportError: from numpy.linalg.linalg import ( - _makearray, _assert_stacked_2d, _assert_stacked_square, - _commonType, isComplexType, _raise_linalgerror_singular + _assert_stacked_2d, + _assert_stacked_square, + _commonType, + _makearray, + _raise_linalgerror_singular, + isComplexType, ) from numpy.linalg import _umath_linalg @@ -61,6 +91,7 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: t, result_t = _commonType(x1, x2) # This part is different from np.linalg.solve + gufunc: np.ufunc if x2.ndim == 1: gufunc = _umath_linalg.solve1 else: @@ -68,23 +99,45 @@ def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: # This does nothing currently but is left in because it will be relevant # when complex dtype support is added to the spec in 2022. - signature = 'DD->D' if isComplexType(t) else 'dd->d' - with _np.errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - r = gufunc(x1, x2, signature=signature) + signature = "DD->D" if isComplexType(t) else "dd->d" + with np.errstate( + call=_raise_linalgerror_singular, + invalid="call", + over="ignore", + divide="ignore", + under="ignore", + ): + r: Array = gufunc(x1, x2, signature=signature) return wrap(r.astype(result_t, copy=False)) + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. -if hasattr(np.linalg, 'vector_norm'): +if hasattr(np.linalg, "vector_norm"): vector_norm = np.linalg.vector_norm else: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = linalg_all + _linalg.__all__ + ['solve'] -del get_xp -del np -del linalg_all -del _linalg +__all__ = [ + "LinAlgError", + "cond", + "det", + "eig", + "eigvals", + "eigvalsh", + "inv", + "lstsq", + "matrix_power", + "multi_dot", + "norm", + "tensorinv", + "tensorsolve", +] +__all__ += _linalg.__all__ +__all__ += ["solve", "vector_norm"] + + +def __dir__() -> list[str]: + return __all__ From 5e14b53a3558765a8f9b921c72f0249cc0c1c5b9 Mon Sep 17 00:00:00 2001 From: Joren Hammudoglu Date: Sat, 19 Apr 2025 16:08:41 +0200 Subject: [PATCH 035/151] TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` (#310) * TYP: auto-plagiarize the optypean `Just*` types * TYP: reject `bool` in the `ord` params of `vector_norm` and `matrix_norm` * TYP: remove accidental type alias * TYP: Tighten the `ord` param of `matrix_norm` Co-authored-by: Lucas Colley --------- Co-authored-by: Lucas Colley --- array_api_compat/common/_linalg.py | 6 ++-- array_api_compat/common/_typing.py | 44 +++++++++++++++++++++++++++++- array_api_compat/torch/linalg.py | 8 ++++-- 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7e002aed..7ad87a1b 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -12,7 +12,7 @@ from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot -from ._typing import Array, DType, Namespace +from ._typing import Array, DType, JustFloat, JustInt, Namespace # These are in the main NumPy namespace but not in numpy.linalg @@ -139,7 +139,7 @@ def matrix_norm( xp: Namespace, *, keepdims: bool = False, - ord: float | Literal["fro", "nuc"] | None = "fro", + ord: Literal[1, 2, -1, -2] | JustFloat | Literal["fro", "nuc"] | None = "fro", ) -> Array: return xp.linalg.norm(x, axis=(-2, -1), keepdims=keepdims, ord=ord) @@ -155,7 +155,7 @@ def vector_norm( *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - ord: float = 2, + ord: JustInt | JustFloat = 2, ) -> Array: # xp.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or # when axis=None and the input is 2-D, so to force a vector norm, we make diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d7deade1..cd26feeb 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -2,7 +2,15 @@ from collections.abc import Mapping from types import ModuleType as Namespace -from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar +from typing import ( + TYPE_CHECKING, + Literal, + Protocol, + TypeAlias, + TypedDict, + TypeVar, + final, +) if TYPE_CHECKING: from _typeshed import Incomplete @@ -21,6 +29,37 @@ _T_co = TypeVar("_T_co", covariant=True) +# These "Just" types are equivalent to the `Just` type from the `optype` library, +# apart from them not being `@runtime_checkable`. +# - docs: https://github.com/jorenham/optype/blob/master/README.md#just +# - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py +@final +class JustInt(Protocol): + @property + def __class__(self, /) -> type[int]: ... + @__class__.setter + def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustFloat(Protocol): + @property + def __class__(self, /) -> type[float]: ... + @__class__.setter + def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +@final +class JustComplex(Protocol): + @property + def __class__(self, /) -> type[complex]: ... + @__class__.setter + def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] + + +# + + class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... @@ -140,6 +179,9 @@ class DTypesAll(DTypesBool, DTypesNumeric): "Device", "HasShape", "Namespace", + "JustInt", + "JustFloat", + "JustComplex", "NestedSequence", "SupportsArrayNamespace", "SupportsBufferProtocol", diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 1ff7319d..70d72405 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -16,6 +16,7 @@ # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot from ._typing import Array, DType +from ..common._typing import JustInt, JustFloat # Note: torch.linalg.cross does not default to axis=-1 (it defaults to the # first axis with size 3), see https://github.com/pytorch/pytorch/issues/58743 @@ -84,8 +85,8 @@ def vector_norm( *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False, - # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float] = 2, + # JustFloat stands for inf | -inf, which are not valid for Literal + ord: JustInt | JustFloat = 2, **kwargs, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None @@ -115,3 +116,6 @@ def vector_norm( _all_ignore = ['torch_linalg', 'sum'] del linalg_all + +def __dir__() -> list[str]: + return __all__ From 52e01beae335c088d25bd6d76f5ae44a231800f5 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 21 Apr 2025 19:38:53 +0100 Subject: [PATCH 036/151] ENH: cache helper functions (#308) * ENH: cache helper functions --- array_api_compat/common/_helpers.py | 192 ++++++++++++++++------------ 1 file changed, 108 insertions(+), 84 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..d50e0d83 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,8 @@ import math import sys import warnings -from collections.abc import Collection +from collections.abc import Collection, Hashable +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -61,23 +62,37 @@ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if "numpy" not in sys.modules or "jax" not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False - import jax - import numpy as np + if "jax" not in sys.modules: + return False - jax_float0 = cast("np.dtype[np.void]", jax.float0) - return ( - isinstance(x, np.ndarray) - and cast("npt.NDArray[np.void]", x).dtype == jax_float0 - ) + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: @@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if "numpy" not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if "cupy" not in sys.modules: - return False - - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") def is_torch_array(x: object) -> TypeIs[torch.Tensor]: @@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "torch" not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: @@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "ndonnx" not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") def is_dask_array(x: object) -> TypeIs[da.Array]: @@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if "dask.array" not in sys.modules: - return False - - import dask.array - - return isinstance(x, dask.array.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") def is_jax_array(x: object) -> TypeIs[jax.Array]: @@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if "jax" not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if "sparse" not in sys.modules: - return False - - import sparse # pyright: ignore[reportMissingTypeStubs] - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] @@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo is_jax_array """ return ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_dask_array(x) - or is_jax_array(x) - or is_pydata_sparse_array(x) - or hasattr(x, "__array_namespace__") + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") ) @@ -317,6 +307,7 @@ def _compat_module_name() -> str: return __name__.removesuffix(".common._helpers") +@lru_cache(100) def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} +@lru_cache(100) def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} +@lru_cache(100) def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == "ndonnx" +@lru_cache(100) def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: return None if math.isnan(out) else out +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. @@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): - return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None def is_lazy_array(x: object) -> bool: @@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. Note that this test may have dire consequences @@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ["sys", "math", "inspect", "warnings"] +_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] def __dir__() -> list[str]: return __all__ From e600449a645c2e6ce5a2276da0006491f097c096 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 24 Apr 2025 10:09:45 +0100 Subject: [PATCH 037/151] ENH: Simplify CuPy `asarray` and `to_device` (#314) reviewed at https://github.com/data-apis/array-api-compat/pull/314 --- array_api_compat/common/_helpers.py | 48 ++++++++++------------------- array_api_compat/cupy/_aliases.py | 30 +++++------------- cupy-xfails.txt | 3 -- tests/test_common.py | 24 +++++++++------ tests/test_cupy.py | 22 +++++++++++++ 5 files changed, 61 insertions(+), 66 deletions(-) create mode 100644 tests/test_cupy.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index d50e0d83..77175d0d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -775,42 +775,28 @@ def _cupy_to_device( /, stream: int | Any | None = None, ) -> _CupyArray: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - from cupy.cuda import Device as _Device # pyright: ignore - from cupy.cuda import stream as stream_module # pyright: ignore - from cupy_backends.cuda.api import runtime # pyright: ignore + import cupy as cp - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] - prev_stream = None - if stream is not None: - prev_stream: Any = stream_module.get_current_stream() # pyright: ignore - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) # pyright: ignore - elif isinstance(stream, cp.cuda.Stream): # pyright: ignore[reportUnknownMemberType] - pass - else: - raise ValueError("the input stream is not recognized") - stream.use() # pyright: ignore[reportUnknownMemberType] - try: - runtime.setDevice(device.id) # pyright: ignore[reportUnknownMemberType] - arr = x.copy() - finally: - runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] - if stream is not None: - prev_stream.use() - return arr + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device type {device!r}") + + if stream is None: + with device: + return cp.asarray(x) + + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif not isinstance(stream, cp.cuda.Stream): + raise TypeError(f"Unsupported stream type {stream!r}") + + with device, stream: + return cp.asarray(x) def _torch_to_device( diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..adb74bff 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -64,8 +64,6 @@ finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() - # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( @@ -79,7 +77,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ @@ -89,25 +87,13 @@ def asarray( specification for more details. """ with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. - - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. - if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + res = cp.array(obj, dtype=dtype, copy=copy, **kwargs) + if not copy and res is not obj: + raise ValueError("Unable to avoid copy while creating an array as requested") + return res def astype( diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a30572f8..df85d9ca 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -11,9 +11,6 @@ array_api_tests/test_array_object.py::test_scalar_casting[__index__(int64)] # testsuite bug (https://github.com/data-apis/array-api-tests/issues/172) array_api_tests/test_array_object.py::test_getitem -# copy=False is not yet implemented -array_api_tests/test_creation_functions.py::test_asarray_arrays - # attributes are np.float32 instead of float # (see also https://github.com/data-apis/array-api/issues/405) array_api_tests/test_data_type_functions.py::test_finfo[float32] diff --git a/tests/test_common.py b/tests/test_common.py index 6b1aa160..d1933899 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -17,6 +17,7 @@ from array_api_compat import ( device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device ) +from array_api_compat.common._helpers import _DASK_DEVICE from ._helpers import all_libraries, import_, wrapped_libraries, xfail @@ -189,23 +190,26 @@ class C: @pytest.mark.parametrize("library", all_libraries) -def test_device(library, request): +def test_device_to_device(library, request): if library == "ndonnx": - xfail(request, reason="Needs ndonnx >=0.9.4") + xfail(request, reason="Stub raises ValueError") + if library == "sparse": + xfail(request, reason="No __array_namespace_info__()") xp = import_(library, wrapper=True) + devices = xp.__array_namespace_info__().devices() - # We can't test much for device() and to_device() other than that - # x.to_device(x.device) works. - + # Default device x = xp.asarray([1, 2, 3]) dev = device(x) - x2 = to_device(x, dev) - assert device(x2) == device(x) - - x3 = xp.asarray(x, device=dev) - assert device(x3) == device(x) + for dev in devices: + if dev is None: # JAX >=0.5.3 + continue + if dev is _DASK_DEVICE: # TODO this needs a better design + continue + y = to_device(x, dev) + assert device(y) == dev @pytest.mark.parametrize("library", wrapped_libraries) diff --git a/tests/test_cupy.py b/tests/test_cupy.py new file mode 100644 index 00000000..f8b4a4d8 --- /dev/null +++ b/tests/test_cupy.py @@ -0,0 +1,22 @@ +import pytest +from array_api_compat import device, to_device + +xp = pytest.importorskip("array_api_compat.cupy") +from cupy.cuda import Stream + + +def test_to_device_with_stream(): + devices = xp.__array_namespace_info__().devices() + streams = [ + Stream(), + Stream(non_blocking=True), + Stream(null=True), + Stream(ptds=True), + 123, # dlpack stream + ] + + a = xp.asarray([1, 2, 3]) + for dev in devices: + for stream in streams: + b = to_device(a, dev, stream=stream) + assert device(b) == dev From 1acba0c1cd06bd26eb526bd08168f2c60f22f0b9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 13:13:46 +0200 Subject: [PATCH 038/151] BUG: take_along_axis: add numpy and cupy aliases, skip testing on dask (#317) --- array_api_compat/cupy/_aliases.py | 8 +++++++- array_api_compat/numpy/_aliases.py | 6 ++++++ dask-xfails.txt | 3 +++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index adb74bff..90b48f05 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -124,6 +124,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return cp.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -145,6 +150,7 @@ def count_nonzero( 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', - 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + 'bool', 'concat', 'count_nonzero', 'pow', 'sign', + 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..a1aee5c0 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -140,6 +140,11 @@ def count_nonzero( return result +# take_along_axis: axis defaults to -1 but in numpy axis is a required arg +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): + return np.take_along_axis(x, indices, axis=axis) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -175,6 +180,7 @@ def count_nonzero( "concat", "count_nonzero", "pow", + "take_along_axis" ] __all__ += _aliases.__all__ _all_ignore = ["np", "get_xp"] diff --git a/dask-xfails.txt b/dask-xfails.txt index 932aeada..3efb4f96 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -24,6 +24,9 @@ array_api_tests/test_creation_functions.py::test_linspace # Shape mismatch array_api_tests/test_indexing_functions.py::test_take +# missing `take_along_axis`, https://github.com/dask/dask/issues/3663 +array_api_tests/test_indexing_functions.py::test_take_along_axis + # Array methods and attributes not already on da.Array cannot be wrapped array_api_tests/test_has_names.py::test_has_names[array_method-__array_namespace__] array_api_tests/test_has_names.py::test_has_names[array_method-to_device] From ddbbc35ab2bebed4637f18e227d6a9138c0f7669 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 16:19:59 +0200 Subject: [PATCH 039/151] TST: add a skip for CuPy pow(er) is not fully NEP50 compatible in CuPy 13.x --- cupy-xfails.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a30572f8..55c6437d 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -37,6 +37,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] # floating point inaccuracy array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] +# incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1) +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 From 7d7a85862b345e0247e20dd64dbe6d327098a869 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 14:46:06 +0000 Subject: [PATCH 040/151] TST: update CuPy skips --- cupy-xfails.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 55c6437d..89e9af54 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -39,6 +39,9 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, array_api_tests/test_operators_and_elementwise_functions.py::test_remainder[remainder(x1, x2)] # incomplete NEP50 support in CuPy 13.x (fixed in 14.0.0a1) array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 @@ -187,7 +190,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[count_nonzero] array_api_tests/test_signatures.py::test_func_signature[bitwise_and] array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_or] From 4e3d809646653a919d9c494b8afd730291e441fb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 17:10:49 +0200 Subject: [PATCH 041/151] TST: add xfails for NumPy 1.22 and 1.26 / python scalars --- numpy-1-22-xfails.txt | 7 +++++++ numpy-1-26-xfails.txt | 2 ++ 2 files changed, 9 insertions(+) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index 93edf311..c1de77d8 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -123,6 +123,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[subtr array_api_tests/test_searching_functions.py::test_where array_api_tests/test_special_cases.py::test_iop[__iadd__(x1_i is -0 and x2_i is -0) -> -0] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] + +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + # 2023.12 support array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 51e1a658..98cb9f6c 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -50,6 +50,8 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars +array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] From 9b8f252683bdd90090649b801dc31402c58fdc96 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 May 2025 22:49:09 +0200 Subject: [PATCH 042/151] BUG: torch: fix count_nonzero with axis tuple and keepdims --- array_api_compat/torch/_aliases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..335008e4 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -548,8 +548,12 @@ def count_nonzero( ) -> Array: result = torch.count_nonzero(x, dim=axis) if keepdims: - if axis is not None: + if isinstance(axis, int): return result.unsqueeze(axis) + elif isinstance(axis, tuple): + n_axis = [x.ndim + ax if ax < 0 else ax for ax in axis] + sh = [1 if i in n_axis else x.shape[i] for i in range(x.ndim)] + return torch.reshape(result, sh) return _axis_none_keepdims(result, x.ndim, keepdims) else: return result From 8c62443da64b2dee5fbf0623f9fd510e62577c45 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 May 2025 23:35:22 +0200 Subject: [PATCH 043/151] TST: update numpy 1.22 xfails --- numpy-1-22-xfails.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index c1de77d8..cacb95b7 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -127,6 +127,13 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From 5597ec755d44cb005f01601b3c2193f9f56b604f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 12:56:49 +0200 Subject: [PATCH 044/151] CI: use ARRAY_API_TESTS_XFAIL_MARK on CI --- .github/workflows/array-api-tests-dask.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-22.yml | 2 ++ .github/workflows/array-api-tests-numpy-1-26.yml | 2 ++ .github/workflows/array-api-tests-numpy-dev.yml | 2 ++ .github/workflows/array-api-tests-numpy-latest.yml | 2 ++ .github/workflows/array-api-tests-torch.yml | 1 + 6 files changed, 11 insertions(+) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index afc67975..964fb52d 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -16,3 +16,5 @@ jobs: # the full test suite with at least 200 examples. pytest-extra-args: --max-examples=5 python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index d8f60432..1cf6e26d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index 33780760..a2788d2f 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,3 +10,5 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index d6de1a53..dce0813f 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,3 +10,5 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 4d3667f6..54b21a25 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,3 +8,5 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + extra-env-vars: | + ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index ac20df25..4dcb3347 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -10,4 +10,5 @@ jobs: extra-requires: '--index-url https://download.pytorch.org/whl/cpu' extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 + ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' From 5cf5d8f404b18ff67543762ed8e92cb0f359f885 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 21:58:35 +0200 Subject: [PATCH 045/151] BUG: torch: meshgrid defaults to indexing="xy" As of version 2.6, torch defaults to indexing='ij', and is planning to transition to 'xy' at some point. When it does, we'll be able to drop our wrapper. --- array_api_compat/torch/_aliases.py | 10 ++++++++-- tests/test_torch.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 335008e4..de5d1a5d 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -2,7 +2,7 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, List, Optional, Sequence, Tuple, Union, Literal import torch @@ -828,6 +828,12 @@ def sign(x: Array, /) -> Array: return out +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: + # enforce the default of 'xy' + # TODO: is the return type a list or a tuple + return list(torch.meshgrid(*arrays, indexing='xy')) + + __all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', @@ -844,6 +850,6 @@ def sign(x: Array, /) -> Array: 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', - 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] + 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] _all_ignore = ['torch', 'get_xp'] diff --git a/tests/test_torch.py b/tests/test_torch.py index e8340f31..7adb4ab3 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -100,3 +100,20 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b): assert dtype_1 == dtype_2 finally: torch.set_default_dtype(prev_default) + + +def test_meshgrid(): + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + + x, y = xp.asarray([1, 2]), xp.asarray([4]) + + X, Y = xp.meshgrid(x, y) + + # output of torch.meshgrid(x, y, indexing='xy') -- indexing='ij' is different + X_xy, Y_xy = xp.asarray([[1, 2]]), xp.asarray([[4, 4]]) + + assert X.shape == X_xy.shape + assert xp.all(X == X_xy) + + assert Y.shape == Y_xy.shape + assert xp.all(Y == Y_xy) From 6488ad81748a1b92f7a0de42e5a10461a9df6b62 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 08:55:20 +0100 Subject: [PATCH 046/151] TST: revisit test for `asarray` `copy=` parameter --- array_api_compat/dask/array/_aliases.py | 2 +- tests/test_common.py | 67 ++++++++++--------------- 2 files changed, 28 insertions(+), 41 deletions(-) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..d43881ab 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -171,7 +171,7 @@ def asarray( return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue] if copy is False: - raise NotImplementedError( + raise ValueError( "Unable to avoid copy when converting a non-dask object to dask" ) diff --git a/tests/test_common.py b/tests/test_common.py index d1933899..fe4fe598 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -278,87 +278,73 @@ def test_asarray_copy(library): xp = import_(library, wrapper=True) asarray = xp.asarray is_lib_func = globals()[is_array_functions[library]] - all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute() - - if library == 'cupy': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = False - elif library == 'dask.array': - supports_copy_false_other_ns = False - supports_copy_false_same_ns = True - else: - supports_copy_false_other_ns = True - supports_copy_false_same_ns = True a = asarray([1]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 1) - assert all(a[0] == 0) + assert b[0] == 1 + assert a[0] == 0 a = asarray([1]) - if supports_copy_false_same_ns: - b = asarray(a, copy=False) - assert is_lib_func(b) - a[0] = 0 - assert all(b[0] == 0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) - a = asarray([1]) - if supports_copy_false_same_ns: - pytest.raises(ValueError, lambda: asarray(a, copy=False, - dtype=xp.float64)) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64)) + # Test copy=False within the same namespace + b = asarray(a, copy=False) + assert is_lib_func(b) + a[0] = 0 + assert b[0] == 0 + with pytest.raises(ValueError): + asarray(a, copy=False, dtype=xp.float64) + # copy=None defaults to False when possible a = asarray([1]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0 - assert all(b[0] == 0) + assert b[0] == 0 + # copy=None defaults to True when impossible a = asarray([1.0], dtype=xp.float32) assert a.dtype == xp.float32 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 + # copy=None defaults to False when possible a = asarray([1.0], dtype=xp.float64) assert a.dtype == xp.float64 b = asarray(a, dtype=xp.float64, copy=None) assert is_lib_func(b) assert b.dtype == xp.float64 a[0] = 0.0 - assert all(b[0] == 0.0) + assert b[0] == 0.0 # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: asarray(obj, copy=True) # No error asarray(obj, copy=None) # No error - if supports_copy_false_other_ns: - pytest.raises(ValueError, lambda: asarray(obj, copy=False)) - else: - pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False)) + + with pytest.raises(ValueError): + asarray(obj, copy=False) # Use the standard library array to test the buffer protocol a = array.array('f', [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 1.0) + assert b[0] == 1.0 a = array.array('f', [1.0]) - if supports_copy_false_other_ns: + if library in ('cupy', 'dask.array'): + with pytest.raises(ValueError): + asarray(a, copy=False) + else: b = asarray(a, copy=False) assert is_lib_func(b) a[0] = 0.0 - assert all(b[0] == 0.0) - else: - pytest.raises(NotImplementedError, lambda: asarray(a, copy=False)) + assert b[0] == 0.0 a = array.array('f', [1.0]) b = asarray(a, copy=None) @@ -369,9 +355,10 @@ def test_asarray_copy(library): # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. # https://github.com/dask/dask/pull/11524/ - assert all(b[0] == 1.0) + assert b[0] == 1.0 else: - assert all(b[0] == 0.0) + # copy=None defaults to False when possible + assert b[0] == 0.0 @pytest.mark.parametrize("library", ["numpy", "cupy", "torch"]) From 7c3d68c47147663399cf4f23de24b9a4193d6f65 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:01:14 +0100 Subject: [PATCH 047/151] TST: fix cupy `to_device` test on multiple devices --- tests/test_cupy.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index f8b4a4d8..fb0c69e4 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -8,15 +8,17 @@ def test_to_device_with_stream(): devices = xp.__array_namespace_info__().devices() streams = [ - Stream(), - Stream(non_blocking=True), - Stream(null=True), - Stream(ptds=True), - 123, # dlpack stream + lambda: Stream(), + lambda: Stream(non_blocking=True), + lambda: Stream(null=True), + lambda: Stream(ptds=True), + lambda: 123, # dlpack stream ] a = xp.asarray([1, 2, 3]) for dev in devices: - for stream in streams: + for stream_gen in streams: + with dev: + stream = stream_gen() b = to_device(a, dev, stream=stream) assert device(b) == dev From c829ef744cb04474b8eedf520557f1ca05bb77dc Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:07:15 +0100 Subject: [PATCH 048/151] nits --- tests/test_cupy.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index fb0c69e4..5aac36f8 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -5,20 +5,26 @@ from cupy.cuda import Stream -def test_to_device_with_stream(): - devices = xp.__array_namespace_info__().devices() - streams = [ +@pytest.mark.parametrize( + "make_stream", + [ lambda: Stream(), - lambda: Stream(non_blocking=True), + lambda: Stream(non_blocking=True), lambda: Stream(null=True), - lambda: Stream(ptds=True), + lambda: Stream(ptds=True), lambda: 123, # dlpack stream - ] + ], +) +def test_to_device_with_stream(make_stream): + devices = xp.__array_namespace_info__().devices() a = xp.asarray([1, 2, 3]) for dev in devices: - for stream_gen in streams: - with dev: - stream = stream_gen() - b = to_device(a, dev, stream=stream) - assert device(b) == dev + # Streams are device-specific and must be created within + # the context of the device... + with dev: + stream = make_stream() + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=stream) + assert device(b) == dev From 44e7828b0666e5edd26958fb2337d755cf1a6002 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:08:18 +0100 Subject: [PATCH 049/151] lint --- tests/test_common.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_common.py b/tests/test_common.py index fe4fe598..54b5ed69 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -268,7 +268,6 @@ def test_asarray_cross_library(source_library, target_library, request): assert b.dtype == tgt_lib.int32 - @pytest.mark.parametrize("library", wrapped_libraries) def test_asarray_copy(library): # Note, we have this test here because the test suite currently doesn't @@ -323,21 +322,21 @@ def test_asarray_copy(library): # Python built-in types for obj in [True, 0, 0.0, 0j, [0], [[0]]]: - asarray(obj, copy=True) # No error - asarray(obj, copy=None) # No error + asarray(obj, copy=True) # No error + asarray(obj, copy=None) # No error with pytest.raises(ValueError): asarray(obj, copy=False) # Use the standard library array to test the buffer protocol - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=True) assert is_lib_func(b) a[0] = 0.0 assert b[0] == 1.0 - a = array.array('f', [1.0]) - if library in ('cupy', 'dask.array'): + a = array.array("f", [1.0]) + if library in ("cupy", "dask.array"): with pytest.raises(ValueError): asarray(a, copy=False) else: @@ -346,11 +345,11 @@ def test_asarray_copy(library): a[0] = 0.0 assert b[0] == 0.0 - a = array.array('f', [1.0]) + a = array.array("f", [1.0]) b = asarray(a, copy=None) assert is_lib_func(b) a[0] = 0.0 - if library in ('cupy', 'dask.array'): + if library in ("cupy", "dask.array"): # A copy is required for libraries where the default device is not CPU # dask changed behaviour of copy=None in 2024.12 to copy; # this wrapper ensures the same behaviour in older versions too. From 0433b8e94ca802d9f6402acacb81f9c4fef6f84a Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:26:53 +0100 Subject: [PATCH 050/151] skip segmentation fault --- tests/test_cupy.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 5aac36f8..8b71d978 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,7 +12,11 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - lambda: 123, # dlpack stream + pytest.param( + lambda: 123, + id="dlpack stream", + marks=pytest.mark.skip(reason="segmentation fault reported (#326)") + ), ], ) def test_to_device_with_stream(make_stream): From ebd3fd9356664c0502506adba96d1df72c47ec49 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:35:08 +0100 Subject: [PATCH 051/151] Use pointers --- tests/test_cupy.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/tests/test_cupy.py b/tests/test_cupy.py index 8b71d978..4745b983 100644 --- a/tests/test_cupy.py +++ b/tests/test_cupy.py @@ -12,11 +12,6 @@ lambda: Stream(non_blocking=True), lambda: Stream(null=True), lambda: Stream(ptds=True), - pytest.param( - lambda: 123, - id="dlpack stream", - marks=pytest.mark.skip(reason="segmentation fault reported (#326)") - ), ], ) def test_to_device_with_stream(make_stream): @@ -32,3 +27,19 @@ def test_to_device_with_stream(make_stream): # device context. b = to_device(a, dev, stream=stream) assert device(b) == dev + + +def test_to_device_with_dlpack_stream(): + devices = xp.__array_namespace_info__().devices() + + a = xp.asarray([1, 2, 3]) + for dev in devices: + # Streams are device-specific and must be created within + # the context of the device... + with dev: + s1 = Stream() + + # ... however, to_device() does not need to be inside the + # device context. + b = to_device(a, dev, stream=s1.ptr) + assert device(b) == dev From 1c53eeb895c5d1ec93db82e813912589a9aa3b41 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 15 May 2025 09:42:29 +0100 Subject: [PATCH 052/151] MAINT: don't import helpers in numpy namespace --- array_api_compat/numpy/__init__.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index f7b558ba..8eab0405 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -23,13 +23,6 @@ __import__(__package__ + ".fft") -from ..common._helpers import * # noqa: F403 -from .linalg import matrix_transpose, vecdot # noqa: F401 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" From e945af9debb715da60b807c028776a4e3d1a0c52 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Thu, 15 May 2025 09:45:10 +0100 Subject: [PATCH 053/151] Update array_api_compat/numpy/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- array_api_compat/numpy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 8eab0405..3e138f53 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -23,6 +23,6 @@ __import__(__package__ + ".fft") -from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" From 6b3ec935eb325d443c327d6490e16e69f273da06 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 13:48:46 +0200 Subject: [PATCH 054/151] DOC: add 1.12 changelog --- docs/changelog.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index 18928e98..c2d5b2c5 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,42 @@ # Changelog +## 1.12.0 (2025-05-13) + + +### Major changes + +- The build system has been updated to use `pyproject.toml` instead of `setup.py` +- Support for Python 3.9 has been dropped. The minimum supported Python version is now + 3.10; the minimum supported NumPy version is 1.22; the minimum supported `ndonnx` + version is 0.10.1. +- The `linalg` extension works correctly with `pytorch==2.7`. +- Multiple improvements to handling of `device` arguments in `numpy`, `cupy`, `torch`, + and `dask` backends. Support for multiple devices is still relatively immature, + and rough edges can be expected. Please report any issues you encounter. + +### Minor changes + +- `finfo` and `iinfo` functions now accept array arguments, in accordance with the + Array API spec; +- `torch.asarray` function propagates the device of the input array. This works around + the [pytorch issue #150199](https://github.com/pytorch/pytorch/issues/150199); +- `torch.repeat` function is now available; +- `torch.count_nonzero` function now correctly handles the case of a tuple `axis` + arguments and `keepdims=True`; +- `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the + array API specification; +- `cupy.asarray` function now implements the `copy=True` argument; + + +The following users contributed to this release: + +Evgeni Burovski, +Lucas Colley, +Neil Girdhar, +Joren Hammudoglu, +Guido Imperiale + + ## 1.11.2 (2025-03-20) This is a bugfix release with no new features compared to version 1.11. From 97e3cc5b1b32bd0a0d5c2a9810df9145012992ed Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 13 May 2025 15:45:34 +0200 Subject: [PATCH 055/151] MAINT: update numpy 1.22 xfails --- numpy-1-22-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index cacb95b7..e0b96c61 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -131,6 +131,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] From 34e2c6f2799e7f0237c035f61b8f1891baae02d7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 15 May 2025 11:47:52 +0200 Subject: [PATCH 056/151] DOC: update the changelog Co-authored-by: Guido Imperiale --- docs/changelog.md | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index c2d5b2c5..c00c62db 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -7,12 +7,12 @@ - The build system has been updated to use `pyproject.toml` instead of `setup.py` - Support for Python 3.9 has been dropped. The minimum supported Python version is now - 3.10; the minimum supported NumPy version is 1.22; the minimum supported `ndonnx` - version is 0.10.1. -- The `linalg` extension works correctly with `pytorch==2.7`. -- Multiple improvements to handling of `device` arguments in `numpy`, `cupy`, `torch`, - and `dask` backends. Support for multiple devices is still relatively immature, - and rough edges can be expected. Please report any issues you encounter. + 3.10; the minimum supported NumPy version is 1.22. +- The `linalg` extension works correctly with `pytorch>=2.7`. +- Multiple improvements to handling of devices in CuPy and PyTorch backends. + Support for multiple devices in CuPy is still immature and you should use + context managers rather than relying on input-output device propagation or + on the `device` parameter. ### Minor changes @@ -25,7 +25,10 @@ arguments and `keepdims=True`; - `torch.meshgrid` wrapper defaults to `indexing="xy"`, in accordance with the array API specification; -- `cupy.asarray` function now implements the `copy=True` argument; +- `cupy.asarray` function now implements the `copy=False` argument, albeit + at the cost of risking to make a temporary copy. +- In `numpy.take_along_axis` and `cupy.take_along_axis` the `axis` parameter now + defaults to -1, in accordance to the Array API spec. The following users contributed to this release: From cdd1213ea28af34b721d105a25d4b7ff2414ef18 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 15 May 2025 11:48:47 +0200 Subject: [PATCH 057/151] Update docs/changelog.md --- docs/changelog.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/changelog.md b/docs/changelog.md index c00c62db..6f6c1251 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -12,7 +12,7 @@ - Multiple improvements to handling of devices in CuPy and PyTorch backends. Support for multiple devices in CuPy is still immature and you should use context managers rather than relying on input-output device propagation or - on the `device` parameter. + on the `device` parameter. Please report any issues you encounter. ### Minor changes From 26a1d2016517ae3bb86ddfef137247fa15ddb512 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 15 May 2025 12:02:16 +0200 Subject: [PATCH 058/151] MAINT: update CuPy xfails --- cupy-xfails.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 77def129..0a91cafe 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -39,6 +39,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[logaddexp] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[nextafter] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[hypot] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 From 8005d6d02c0f1717881de37a710871bb955eb5cd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 13 May 2025 14:31:01 +0200 Subject: [PATCH 059/151] REL: bump the version to 1.12.0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 60b37e97..653cb40a 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.dev0' +__version__ = '1.12.0' from .common import * # noqa: F401, F403 From 91dd626ce8b2612979e513af235be3809791f94b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 16 May 2025 10:58:43 +0200 Subject: [PATCH 060/151] MAINT: bump version to 1.13.0.dev0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 653cb40a..a00e8cbc 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.12.0' +__version__ = '1.13.0.dev0' from .common import * # noqa: F401, F403 From 3350f670e1b67a37888228c102e0e560f43077bd Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 May 2025 14:12:51 +0200 Subject: [PATCH 061/151] CI: run 500 examples on NumPy and PyTorch; 50 on Dask --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index 964fb52d..a60b28a4 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=5 + pytest-extra-args: --max-examples=50 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 31bedde6..f652438b 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 200 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 10" + PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: From 5b1ece468fb9b9b789304b57035de6801a39c7b1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 14:18:30 +0200 Subject: [PATCH 062/151] CI: use 4 workers; bump the # of examples to 1000 (np/torch), 200 (dask) --- .github/workflows/array-api-tests-dask.yml | 2 +- .github/workflows/array-api-tests-numpy-1-22.yml | 1 + .github/workflows/array-api-tests-numpy-1-26.yml | 1 + .github/workflows/array-api-tests-numpy-dev.yml | 1 + .github/workflows/array-api-tests-numpy-latest.yml | 1 + .github/workflows/array-api-tests-torch.yml | 1 + .github/workflows/array-api-tests.yml | 3 ++- 7 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/array-api-tests-dask.yml b/.github/workflows/array-api-tests-dask.yml index a60b28a4..ef430d9c 100644 --- a/.github/workflows/array-api-tests-dask.yml +++ b/.github/workflows/array-api-tests-dask.yml @@ -14,7 +14,7 @@ jobs: # workflow is barely more than a smoke test, and one should expect extreme # flakiness. Before changes to dask-xfails.txt or dask-skips.txt, please run # the full test suite with at least 200 examples. - pytest-extra-args: --max-examples=50 + pytest-extra-args: --max-examples=200 -n 4 python-versions: '[''3.10'', ''3.13'']' extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-22.yml b/.github/workflows/array-api-tests-numpy-1-22.yml index 1cf6e26d..83d4cf1d 100644 --- a/.github/workflows/array-api-tests-numpy-1-22.yml +++ b/.github/workflows/array-api-tests-numpy-1-22.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.22.*' xfails-file-extra: '-1-22' python-versions: '[''3.10'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-1-26.yml b/.github/workflows/array-api-tests-numpy-1-26.yml index a2788d2f..13124644 100644 --- a/.github/workflows/array-api-tests-numpy-1-26.yml +++ b/.github/workflows/array-api-tests-numpy-1-26.yml @@ -10,5 +10,6 @@ jobs: package-version: '== 1.26.*' xfails-file-extra: '-1-26' python-versions: '[''3.10'', ''3.12'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index dce0813f..dec4c7ae 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -10,5 +10,6 @@ jobs: extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' python-versions: '[''3.11'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 54b21a25..65bbc9a2 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -8,5 +8,6 @@ jobs: with: package-name: numpy python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 4dcb3347..4b4b945e 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -12,3 +12,4 @@ jobs: ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 ARRAY_API_TESTS_XFAIL_MARK=skip python-versions: '[''3.10'', ''3.13'']' + pytest-extra-args: -n 4 diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index f652438b..53c1474d 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 500 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: @@ -76,6 +76,7 @@ jobs: python -m pip install --upgrade pip python -m pip install '${{ inputs.package-name }} ${{ inputs.package-version }}' ${{ inputs.extra-requires }} python -m pip install -r ${GITHUB_WORKSPACE}/array-api-tests/requirements.txt + python -m pip install pytest-xdist - name: Dump pip environment run: pip freeze From 37d5d668674f10ae41709a7ebc3d2ab2ae6b25c4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:01 +0200 Subject: [PATCH 063/151] MAINT: update numpy-1.22 xfails --- numpy-1-22-xfails.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index cacb95b7..d4022b31 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -133,7 +133,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] -array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From 43435808041951df2c7b7cae28204b3ce61f6e46 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 12 May 2025 16:42:29 +0200 Subject: [PATCH 064/151] MAINT: remove --ci pytest switch The warning says it's deprecated. --- .github/workflows/array-api-tests.yml | 2 +- numpy-1-22-xfails.txt | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 53c1474d..e832f870 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -37,7 +37,7 @@ on: description: Multiline string of environment variables to set for the test run. env: - PYTEST_ARGS: "--max-examples 1000 -v -rxXfE --ci ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" + PYTEST_ARGS: "--max-examples 1000 -v -rxXfE ${{ inputs.pytest-extra-args }} --hypothesis-disable-deadline --durations 20" jobs: tests: diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index d4022b31..e1c4f832 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -134,6 +134,7 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[pow] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars From c03daa36c09d51162d240b77e223a49cc8a6076e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:36:11 +0200 Subject: [PATCH 065/151] CI: install jax/sparse/torch in more jobs Also, `ndonnx` has wheels for all python versions now; And we do not bother with jax or dask numpy < 1. --- .github/workflows/tests.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 81a05b3f..c995b370 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -32,20 +32,20 @@ jobs: python -m pip install --upgrade pip python -m pip install pytest + # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack + python -m pip install array-api-strict + python -m pip install torch --index-url https://download.pytorch.org/whl/cpu + if [ "${{ matrix.numpy-version }}" == "dev" ]; then python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -m pip install dask[array] jax[cpu] sparse ndonnx elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then python -m pip install 'numpy==1.22.*' elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then python -m pip install 'numpy==1.26.*' else - # Don't `pip install .[dev]` as it would pull in the whole torch cuda stack - python -m pip install array-api-strict dask[array] jax[cpu] numpy sparse - python -m pip install torch --index-url https://download.pytorch.org/whl/cpu - if [ "${{ matrix.python-version }}" != "3.13" ]; then - # onnx wheels are not available on Python 3.13 at the moment of writing - python -m pip install ndonnx - fi + python -m pip install numpy + python -m pip install dask[array] jax[cpu] sparse ndonnx fi - name: Dump pip environment From a8e19835092335ab8e1846f1e3dda335d8eb4c4a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:38:31 +0200 Subject: [PATCH 066/151] TST: xfail test_device_to_device with numpy < 2 It assumes that asarray has the copy kwarg, and this is not true in NumPy < 2. --- tests/test_common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_common.py b/tests/test_common.py index 54b5ed69..85ed032e 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -195,6 +195,9 @@ def test_device_to_device(library, request): xfail(request, reason="Stub raises ValueError") if library == "sparse": xfail(request, reason="No __array_namespace_info__()") + if library == "array_api_strict": + if np.__version__ < "2": + xfail(request, reason="no copy argument of np.asarray") xp = import_(library, wrapper=True) devices = xp.__array_namespace_info__().devices() From 8e3ab3e7c5c6794f66196ec435d2f6bdd1492404 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 13:39:28 +0200 Subject: [PATCH 067/151] MAINT: filter out some warning noise --- tests/test_array_namespace.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index cdb80007..2fbb0339 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -23,7 +23,9 @@ def test_array_namespace(library, api_version, use_compat): if library == "ndonnx" and api_version in ("2021.12", "2022.12"): pytest.skip("Unsupported API version") - namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: if library == "jax.numpy" and use_compat is None: @@ -45,10 +47,13 @@ def test_array_namespace(library, api_version, use_compat): if library == "numpy": # check that the same namespace is returned for NumPy scalars - scalar_namespace = array_namespace( - xp.float64(0.0), api_version=api_version, use_compat=use_compat - ) - assert scalar_namespace == namespace + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + + scalar_namespace = array_namespace( + xp.float64(0.0), api_version=api_version, use_compat=use_compat + ) + assert scalar_namespace == namespace # Check that array_namespace works even if jax.experimental.array_api # hasn't been imported yet (it monkeypatches __array_namespace__ @@ -97,7 +102,9 @@ def test_api_version_torch(): torch = import_("torch") x = torch.asarray([1, 2]) torch_ = import_("torch", wrapper=True) - assert array_namespace(x, api_version="2023.12") == torch_ + with warnings.catch_warnings(): + warnings.simplefilter('ignore', UserWarning) + assert array_namespace(x, api_version="2023.12") == torch_ assert array_namespace(x, api_version=None) == torch_ assert array_namespace(x) == torch_ # Should issue a warning From 9959873e351ecab696538650893afc2faef17a38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Jun 2025 10:10:33 +0000 Subject: [PATCH 068/151] Bump dawidd6/action-download-artifact from 9 to 10 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 9 to 10 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v9...v10) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '10' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fc612588..4e3efb39 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v9 + uses: dawidd6/action-download-artifact@v10 with: workflow: docs-build.yml name: docs-build From 6ae28ee9538820ae09ba45d8ef3d15d4a6570900 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 2 Jun 2025 18:38:34 +0100 Subject: [PATCH 069/151] ENH: speed up `array_namespace` * ENH: speed up `array_namespace` * jax 0.6.1 Reviewed at https://github.com/data-apis/array-api-compat/pull/329 --- array_api_compat/common/_helpers.py | 196 +++++++++++++++------------- tests/test_array_namespace.py | 121 ++++++++--------- 2 files changed, 161 insertions(+), 156 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 77175d0d..a152e4c0 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import inspect import math import sys @@ -485,6 +486,86 @@ def _check_api_version(api_version: str | None) -> None: ) +class _ClsToXPInfo(enum.Enum): + SCALAR = 0 + MAYBE_JAX_ZERO_GRADIENT = 1 + + +@lru_cache(100) +def _cls_to_namespace( + cls: type, + api_version: str | None, + use_compat: bool | None, +) -> tuple[Namespace | None, _ClsToXPInfo | None]: + if use_compat not in (None, True, False): + raise ValueError("use_compat must be None, True, or False") + _use_compat = use_compat in (None, True) + cls_ = cast(Hashable, cls) # Make mypy happy + + if ( + _issubclass_fast(cls_, "numpy", "ndarray") + or _issubclass_fast(cls_, "numpy", "generic") + ): + if use_compat is True: + _check_api_version(api_version) + from .. import numpy as xp + elif use_compat is False: + import numpy as xp # type: ignore[no-redef] + else: + # NumPy 2.0+ have __array_namespace__; however they are not + # yet fully array API compatible. + from .. import numpy as xp # type: ignore[no-redef] + return xp, _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + + # Note: this must happen _after_ the test for np.generic, + # because np.float64 and np.complex128 are subclasses of float and complex. + if issubclass(cls, int | float | complex | type(None)): + return None, _ClsToXPInfo.SCALAR + + if _issubclass_fast(cls_, "cupy", "ndarray"): + if _use_compat: + _check_api_version(api_version) + from .. import cupy as xp # type: ignore[no-redef] + else: + import cupy as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "torch", "Tensor"): + if _use_compat: + _check_api_version(api_version) + from .. import torch as xp # type: ignore[no-redef] + else: + import torch as xp # type: ignore[no-redef] + return xp, None + + if _issubclass_fast(cls_, "dask.array", "Array"): + if _use_compat: + _check_api_version(api_version) + from ..dask import array as xp # type: ignore[no-redef] + else: + import dask.array as xp # type: ignore[no-redef] + return xp, None + + # Backwards compatibility for jax<0.4.32 + if _issubclass_fast(cls_, "jax", "Array"): + return _jax_namespace(api_version, use_compat), None + + return None, None + + +def _jax_namespace(api_version: str | None, use_compat: bool | None) -> Namespace: + if use_compat: + raise ValueError("JAX does not have an array-api-compat wrapper") + import jax.numpy as jnp + if not hasattr(jnp, "__array_namespace_info__"): + # JAX v0.4.32 and newer implements the array API directly in jax.numpy. + # For older JAX versions, it is available via jax.experimental.array_api. + # jnp.Array objects gain the __array_namespace__ method. + import jax.experimental.array_api # noqa: F401 + # Test api_version + return jnp.empty(0).__array_namespace__(api_version=api_version) + + def array_namespace( *xs: Array | complex | None, api_version: str | None = None, @@ -553,105 +634,40 @@ def your_function(x, y): is_pydata_sparse_array """ - if use_compat not in [None, True, False]: - raise ValueError("use_compat must be None, True, or False") - - _use_compat = use_compat in [None, True] - namespaces: set[Namespace] = set() for x in xs: - if is_numpy_array(x): - import numpy as np - - from .. import numpy as numpy_namespace - - if use_compat is True: - _check_api_version(api_version) - namespaces.add(numpy_namespace) - elif use_compat is False: - namespaces.add(np) - else: - # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API - # compatible. - namespaces.add(numpy_namespace) - elif is_cupy_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import cupy as cupy_namespace - - namespaces.add(cupy_namespace) - else: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - namespaces.add(cp) - elif is_torch_array(x): - if _use_compat: - _check_api_version(api_version) - from .. import torch as torch_namespace - - namespaces.add(torch_namespace) - else: - import torch - - namespaces.add(torch) - elif is_dask_array(x): - if _use_compat: - _check_api_version(api_version) - from ..dask import array as dask_namespace - - namespaces.add(dask_namespace) - else: - import dask.array as da - - namespaces.add(da) - elif is_jax_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("JAX does not have an array-api-compat wrapper") - elif use_compat is False: - import jax.numpy as jnp - else: - # JAX v0.4.32 and newer implements the array API directly in jax.numpy. - # For older JAX versions, it is available via jax.experimental.array_api. - import jax.numpy - - if hasattr(jax.numpy, "__array_api_version__"): - jnp = jax.numpy - else: - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] - namespaces.add(jnp) - elif is_pydata_sparse_array(x): - if use_compat is True: - _check_api_version(api_version) - raise ValueError("`sparse` does not have an array-api-compat wrapper") - else: - import sparse # pyright: ignore[reportMissingTypeStubs] - # `sparse` is already an array namespace. We do not have a wrapper - # submodule for it. - namespaces.add(sparse) - elif hasattr(x, "__array_namespace__"): - if use_compat is True: + xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) + if info is _ClsToXPInfo.SCALAR: + continue + + if ( + info is _ClsToXPInfo.MAYBE_JAX_ZERO_GRADIENT + and _is_jax_zero_gradient_array(x) + ): + xp = _jax_namespace(api_version, use_compat) + + if xp is None: + get_ns = getattr(x, "__array_namespace__", None) + if get_ns is None: + raise TypeError(f"{type(x).__name__} is not a supported array type") + if use_compat: raise ValueError( "The given array does not have an array-api-compat wrapper" ) - x = cast("SupportsArrayNamespace[Any]", x) - namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): - continue - else: - # TODO: Support Python scalars? - raise TypeError(f"{type(x).__name__} is not a supported array type") + xp = get_ns(api_version=api_version) - if not namespaces: - raise TypeError("Unrecognized array input") + namespaces.add(xp) - if len(namespaces) != 1: + try: + (xp,) = namespaces + return xp + except ValueError: + if not namespaces: + raise TypeError( + "array_namespace requires at least one non-scalar array input" + ) raise TypeError(f"Multiple namespaces for array inputs: {namespaces}") - (xp,) = namespaces - - return xp - # backwards compatibility alias get_namespace = array_namespace diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 2fbb0339..311efc37 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -8,42 +8,41 @@ import array_api_compat from array_api_compat import array_namespace -from ._helpers import import_, all_libraries, wrapped_libraries +from ._helpers import all_libraries, wrapped_libraries, xfail + @pytest.mark.parametrize("use_compat", [True, False, None]) -@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12", "2023.12"]) +@pytest.mark.parametrize( + "api_version", [None, "2021.12", "2022.12", "2023.12", "2024.12"] +) @pytest.mark.parametrize("library", all_libraries) -def test_array_namespace(library, api_version, use_compat): - xp = import_(library) +def test_array_namespace(request, library, api_version, use_compat): + xp = pytest.importorskip(library) array = xp.asarray([1.0, 2.0, 3.0]) if use_compat and library not in wrapped_libraries: pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat)) return - if library == "ndonnx" and api_version in ("2021.12", "2022.12"): - pytest.skip("Unsupported API version") + if (library == "sparse" and api_version in ("2023.12", "2024.12")) or ( + library == "jax.numpy" and api_version in ("2021.12", "2022.12", "2023.12") + ): + xfail(request, "Unsupported API version") with warnings.catch_warnings(): warnings.simplefilter('ignore', UserWarning) namespace = array_namespace(array, api_version=api_version, use_compat=use_compat) if use_compat is False or use_compat is None and library not in wrapped_libraries: - if library == "jax.numpy" and use_compat is None: - import jax.numpy - if hasattr(jax.numpy, "__array_api_version__"): - # JAX v0.4.32 or later uses jax.numpy directly - assert namespace == jax.numpy - else: - # JAX v0.4.31 or earlier uses jax.experimental.array_api - import jax.experimental.array_api - assert namespace == jax.experimental.array_api + if library == "jax.numpy" and not hasattr(xp, "__array_api_version__"): + # Backwards compatibility for JAX <0.4.32 + import jax.experimental.array_api + assert namespace == jax.experimental.array_api else: assert namespace == xp + elif library == "dask.array": + assert namespace == array_api_compat.dask.array else: - if library == "dask.array": - assert namespace == array_api_compat.dask.array - else: - assert namespace == getattr(array_api_compat, library) + assert namespace == getattr(array_api_compat, library) if library == "numpy": # check that the same namespace is returned for NumPy scalars @@ -55,20 +54,20 @@ def test_array_namespace(library, api_version, use_compat): ) assert scalar_namespace == namespace - # Check that array_namespace works even if jax.experimental.array_api - # hasn't been imported yet (it monkeypatches __array_namespace__ - # onto JAX arrays, but we should support them regardless). The only way to - # do this is to use a subprocess, since we cannot un-import it and another - # test probably already imported it. - if library == "jax.numpy" and sys.version_info >= (3, 9): - code = f"""\ + +def test_jax_backwards_compat(): + """On JAX <0.4.32, test that array_namespace works even if + jax.experimental.array_api has not been imported yet. + """ + pytest.importorskip("jax") + code = """\ import sys import jax.numpy import array_api_compat -array = jax.numpy.asarray([1.0, 2.0, 3.0]) +array = jax.numpy.asarray([1.0, 2.0, 3.0]) assert 'jax.experimental.array_api' not in sys.modules -namespace = array_api_compat.array_namespace(array, api_version={api_version!r}) +namespace = array_api_compat.array_namespace(array) if hasattr(jax.numpy, '__array_api_version__'): assert namespace == jax.numpy @@ -76,14 +75,16 @@ def test_array_namespace(library, api_version, use_compat): import jax.experimental.array_api assert namespace == jax.experimental.array_api """ - subprocess.run([sys.executable, "-c", code], check=True) + subprocess.check_call([sys.executable, "-c", code]) + def test_jax_zero_gradient(): - jax = import_("jax") + jax = pytest.importorskip("jax") jx = jax.numpy.arange(4) jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx) assert array_namespace(jax_zero) is array_namespace(jx) + def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace([1])) pytest.raises(TypeError, lambda: array_namespace()) @@ -92,43 +93,31 @@ def test_array_namespace_errors(): pytest.raises(TypeError, lambda: array_namespace((x, x))) pytest.raises(TypeError, lambda: array_namespace(x, (x, x))) -def test_array_namespace_errors_torch(): - torch = import_("torch") - y = torch.asarray([1, 2]) - x = np.asarray([1, 2]) - pytest.raises(TypeError, lambda: array_namespace(x, y)) -def test_api_version_torch(): - torch = import_("torch") - x = torch.asarray([1, 2]) - torch_ = import_("torch", wrapper=True) - with warnings.catch_warnings(): - warnings.simplefilter('ignore', UserWarning) - assert array_namespace(x, api_version="2023.12") == torch_ - assert array_namespace(x, api_version=None) == torch_ - assert array_namespace(x) == torch_ - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2021.12") == torch_ - assert len(w) == 1 - assert "2021.12" in str(w[0].message) - - # Should issue a warning - with warnings.catch_warnings(record=True) as w: - assert array_namespace(x, api_version="2022.12") == torch_ - assert len(w) == 1 - assert "2022.12" in str(w[0].message) - - pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12")) +@pytest.mark.parametrize("library", all_libraries) +def test_array_namespace_many_args(library): + xp = pytest.importorskip(library) + a = xp.asarray(1) + b = xp.asarray(2) + assert array_namespace(a, b) is array_namespace(a) + + +def test_array_namespace_mismatch(): + xp = pytest.importorskip("array_api_strict") + with pytest.raises(TypeError, match="Multiple namespaces"): + array_namespace(np.asarray(1), xp.asarray(1)) + def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_namespace -def test_python_scalars(): - torch = import_("torch") - a = torch.asarray([1, 2]) - xp = import_("torch", wrapper=True) + +@pytest.mark.parametrize("library", all_libraries) +def test_python_scalars(library): + xp = pytest.importorskip(library) + a = xp.asarray([1, 2]) + xp = array_namespace(a) pytest.raises(TypeError, lambda: array_namespace(1)) pytest.raises(TypeError, lambda: array_namespace(1.0)) @@ -136,8 +125,8 @@ def test_python_scalars(): pytest.raises(TypeError, lambda: array_namespace(True)) pytest.raises(TypeError, lambda: array_namespace(None)) - assert array_namespace(a, 1) == xp - assert array_namespace(a, 1.0) == xp - assert array_namespace(a, 1j) == xp - assert array_namespace(a, True) == xp - assert array_namespace(a, None) == xp + assert array_namespace(a, 1) is xp + assert array_namespace(a, 1.0) is xp + assert array_namespace(a, 1j) is xp + assert array_namespace(a, True) is xp + assert array_namespace(a, None) is xp From c9cfc2c9193fcdf0e52a2bbdace54182780839c9 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 11:37:49 +0200 Subject: [PATCH 070/151] TST: add a test that wrapping preserves a view/copy semantics for unary functions If a bare library returns a copy, so does the wrapped library; if the bare library returns a view, so does the wrapped library. --- tests/test_copies_or_views.py | 66 +++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 tests/test_copies_or_views.py diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py new file mode 100644 index 00000000..5b9b9207 --- /dev/null +++ b/tests/test_copies_or_views.py @@ -0,0 +1,66 @@ +""" +A collection of tests to make sure that wrapped namespaces agree with the bare ones +on whether to return a view or a copy of inputs. +""" +import pytest +from ._helpers import import_ + + +LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] + +FUNC_INPUTS = [ + # func_name, arr_input, dtype, scalar_value + ('abs', [1, 2], 'int8', 3), + ('abs', [1, 2], 'float32', 3.), + ('ceil', [1, 2], 'int8', 3), + ('clip', [1, 2], 'int8', 3), + ('conj', [1, 2], 'int8', 3), + ('floor', [1, 2], 'int8', 3), + ('imag', [1j, 2j], 'complex64', 3), + ('positive', [1, 2], 'int8', 3), + ('real', [1., 2.], 'float32', 3.), + ('round', [1, 2], 'int8', 3), + ('sign', [0, 0], 'float32', 3), + ('trunc', [1, 2], 'int8', 3), + ('trunc', [1, 2], 'float32', 3), +] + + +def ensure_unary(func, arr): + """Make a trivial unary function from func.""" + if func.__name__ == 'clip': + return lambda x: func(x, arr[0], arr[1]) + return func + + +def is_view(func, a, value): + """Apply `func`, mutate the output; does the input change?""" + b = func(a) + b[0] = value + return a[0] == value + + +@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) +def test_view_or_copy(inputs, xp_name): + bare_xp = import_(xp_name, wrapper=False) + wrapped_xp = import_(xp_name, wrapper=True) + + func_name, arr_input, dtype_str, value = inputs + dtype = getattr(bare_xp, dtype_str) + + bare_func = getattr(bare_xp, func_name) + bare_func = ensure_unary(bare_func, arr_input) + + wrapped_func = getattr(wrapped_xp, func_name) + wrapped_func = ensure_unary(wrapped_func, arr_input) + + # bare namespace: mutate the output, does the input change? + a = bare_xp.asarray(arr_input, dtype=dtype) + is_view_bare = is_view(bare_func, a, value) + + # wrapped namespace: mutate the output, does the input change? + a1 = wrapped_xp.asarray(arr_input, dtype=dtype) + is_view_wrapped = is_view(wrapped_func, a1, value) + + assert is_view_bare == is_view_wrapped From 1facc3526414926b2d123e88c16f7d517d9d2558 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 1 Jun 2025 14:30:54 +0200 Subject: [PATCH 071/151] BUG: make ceil,trunc,floor always respect view/copy semantics Remove these functions from common/_aliases.py, add specific implementations for numpy < 2 and cupy. --- array_api_compat/common/_aliases.py | 24 -------------------- array_api_compat/cupy/_aliases.py | 24 ++++++++++++++++---- array_api_compat/dask/array/_aliases.py | 3 --- array_api_compat/numpy/_aliases.py | 29 ++++++++++++++++++++++--- 4 files changed, 46 insertions(+), 34 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..39d10860 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -524,27 +524,6 @@ def nonzero(x: Array, /, xp: Namespace, **kwargs: object) -> tuple[Array, ...]: return xp.nonzero(x, **kwargs) -# ceil, floor, and trunc return integers for integer inputs - - -def ceil(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.ceil(x, **kwargs) - - -def floor(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.floor(x, **kwargs) - - -def trunc(x: Array, /, xp: Namespace, **kwargs: object) -> Array: - if xp.issubdtype(x.dtype, xp.integer): - return x - return xp.trunc(x, **kwargs) - - # linear algebra functions @@ -707,9 +686,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "argsort", "sort", "nonzero", - "ceil", - "floor", - "trunc", "matmul", "matrix_transpose", "tensordot", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..e000602e 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -54,9 +54,6 @@ argsort = get_xp(cp)(_aliases.argsort) sort = get_xp(cp)(_aliases.sort) nonzero = get_xp(cp)(_aliases.nonzero) -ceil = get_xp(cp)(_aliases.ceil) -floor = get_xp(cp)(_aliases.floor) -trunc = get_xp(cp)(_aliases.trunc) matmul = get_xp(cp)(_aliases.matmul) matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) @@ -123,6 +120,25 @@ def count_nonzero( return cp.expand_dims(result, axis) return result +# ceil, floor, and trunc return integers for integer inputs + +def ceil(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.ceil(x) + + +def floor(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.floor(x) + + +def trunc(x: Array, /) -> Array: + if cp.issubdtype(x.dtype, cp.integer): + return x.copy() + return cp.trunc(x) + # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): @@ -151,6 +167,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..0bb5d227 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -134,9 +134,6 @@ def arange( matrix_transpose = get_xp(da)(_aliases.matrix_transpose) vecdot = get_xp(da)(_aliases.vecdot) nonzero = get_xp(da)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) tensordot = get_xp(np)(_aliases.tensordot) sign = get_xp(np)(_aliases.sign) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..502dfb3a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -63,9 +63,6 @@ argsort = get_xp(np)(_aliases.argsort) sort = get_xp(np)(_aliases.sort) nonzero = get_xp(np)(_aliases.nonzero) -ceil = get_xp(np)(_aliases.ceil) -floor = get_xp(np)(_aliases.floor) -trunc = get_xp(np)(_aliases.trunc) matmul = get_xp(np)(_aliases.matmul) matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) @@ -145,6 +142,29 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): return np.take_along_axis(x, indices, axis=axis) +# ceil, floor, and trunc return integers for integer inputs in NumPy < 2 + +def ceil(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.ceil(x) + + +def floor(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.floor(x) + + +def trunc(x: Array, /) -> Array: + if np.issubdtype(x.dtype, np.integer): + if np.__version__ < '2': + return x.copy() + return np.trunc(x) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np, "vecdot"): @@ -173,6 +193,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): "atan", "atan2", "atanh", + "ceil", + "floor", + "trunc", "bitwise_left_shift", "bitwise_invert", "bitwise_right_shift", From 0ad664bdfde03ec3f21d82b1048616ae5d0fb6b7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:49:43 +0200 Subject: [PATCH 072/151] Apply suggestions from code review Co-authored-by: Guido Imperiale --- tests/test_copies_or_views.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 5b9b9207..24d03547 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -3,7 +3,7 @@ on whether to return a view or a copy of inputs. """ import pytest -from ._helpers import import_ +from ._helpers import import_, wrapped_libraries LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] @@ -40,7 +40,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', LIB_NAMES) +@pytest.mark.parametrize('xp_name', wrapped_libraries) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From 118ae2d0428be763abf1e31b2827a4800398e901 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 10:52:46 +0200 Subject: [PATCH 073/151] TST: test views vs copies on array-api-strict, too --- tests/test_copies_or_views.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index 24d03547..ec8995f7 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -6,8 +6,6 @@ from ._helpers import import_, wrapped_libraries -LIB_NAMES = ['numpy', 'cupy', 'torch', 'dask.array', 'array_api_strict'] - FUNC_INPUTS = [ # func_name, arr_input, dtype, scalar_value ('abs', [1, 2], 'int8', 3), @@ -40,7 +38,7 @@ def is_view(func, a, value): return a[0] == value -@pytest.mark.parametrize('xp_name', wrapped_libraries) +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) @pytest.mark.parametrize('inputs', FUNC_INPUTS, ids=[inp[0] for inp in FUNC_INPUTS]) def test_view_or_copy(inputs, xp_name): bare_xp = import_(xp_name, wrapper=False) From b0eed557d6dba8c87d9693ff82360b33c1af3480 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 4 Jun 2025 13:05:08 +0200 Subject: [PATCH 074/151] Apply suggestions from code review Co-authored-by: Guido Imperiale --- array_api_compat/numpy/_aliases.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 502dfb3a..f04837de 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -145,23 +145,20 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): # ceil, floor, and trunc return integers for integer inputs in NumPy < 2 def ceil(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.ceil(x) def floor(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.floor(x) def trunc(x: Array, /) -> Array: - if np.issubdtype(x.dtype, np.integer): - if np.__version__ < '2': - return x.copy() + if np.__version__ < '2' and np.issubdtype(x.dtype, np.integer): + return x.copy() return np.trunc(x) From 2b559e62e05ebea3dd3ab631aee47b270109eaa1 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Wed, 4 Jun 2025 15:36:16 +0100 Subject: [PATCH 075/151] TYP: Type annotations, part 4 (#313) * Type annotations, part 4 * Fix CopyMode * revert * Revert `_all_ignore` * code review * code review * JustInt mypy ignores * lint * fix merge * lint * Reverts and tweaks * Fix test_all * Revert batmobile --- array_api_compat/_internal.py | 4 +- array_api_compat/common/_aliases.py | 13 +- array_api_compat/common/_helpers.py | 34 ++--- array_api_compat/common/_linalg.py | 6 +- array_api_compat/common/_typing.py | 15 +- array_api_compat/cupy/_aliases.py | 31 ++-- array_api_compat/cupy/fft.py | 9 +- array_api_compat/cupy/linalg.py | 2 +- array_api_compat/dask/array/__init__.py | 2 +- array_api_compat/dask/array/_aliases.py | 2 +- array_api_compat/dask/array/_info.py | 47 +++--- array_api_compat/dask/array/fft.py | 2 +- array_api_compat/dask/array/linalg.py | 19 +-- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/numpy/_aliases.py | 33 ++-- array_api_compat/numpy/_info.py | 3 +- array_api_compat/numpy/linalg.py | 4 +- array_api_compat/torch/_aliases.py | 192 ++++++++++++------------ array_api_compat/torch/fft.py | 19 +-- array_api_compat/torch/linalg.py | 12 +- pyproject.toml | 56 ++++--- 21 files changed, 246 insertions(+), 261 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..b1925492 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object: specification for more details. """ - wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] - return wrapped_f # pyright: ignore[reportReturnType] + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..51732b91 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,11 +5,12 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any, NamedTuple, cast from ._helpers import _check_device, array_namespace from ._helpers import device as _get_device -from ._helpers import is_cupy_namespace as _is_cupy_namespace +from ._helpers import is_cupy_namespace from ._typing import Array, Device, DType, Namespace if TYPE_CHECKING: @@ -381,8 +382,8 @@ def clip( # TODO: np.clip has other ufunc kwargs out: Array | None = None, ) -> Array: - def _isscalar(a: object) -> TypeIs[int | float | None]: - return isinstance(a, (int, float, type(None))) + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float) or a is None min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -450,7 +451,7 @@ def reshape( shape: tuple[int, ...], xp: Namespace, *, - copy: Optional[bool] = None, + copy: bool | None = None, **kwargs: object, ) -> Array: if copy is True: @@ -657,7 +658,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index a152e4c0..cae0ee0b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -23,7 +23,6 @@ SupportsIndex, TypeAlias, TypeGuard, - TypeVar, cast, overload, ) @@ -31,32 +30,29 @@ from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - + import cupy as cp import dask.array as da import jax import ndonnx as ndx import numpy as np import numpy.typing as npt - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse import torch # TODO: import from typing (requires Python >=3.13) - from typing_extensions import TypeIs, TypeVar - - _SizeT = TypeVar("_SizeT", bound = int | None) + from typing_extensions import TypeIs _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] - _CupyArray: TypeAlias = Any # cupy has no py.typed _ArrayApiObj: TypeAlias = ( npt.NDArray[Any] + | cp.ndarray | da.Array | jax.Array | ndx.Array | sparse.SparseArray | torch.Tensor | SupportsArrayNamespace[Any] - | _CupyArray ) _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) @@ -96,7 +92,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: return dtype == jax.float0 -def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -267,7 +263,7 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: return _issubclass_fast(cls, "sparse", "SparseArray") -def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. @@ -748,7 +744,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): # pyright: ignore + if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -777,7 +773,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" # Return the device of the constituent array return device(inner) # pyright: ignore - return x.device # pyright: ignore + return x.device # type: ignore # pyright: ignore # Prevent shadowing, used below @@ -786,11 +782,11 @@ def device(x: _ArrayApiObj, /) -> Device: # Based on cupy.array_api.Array.to_device def _cupy_to_device( - x: _CupyArray, + x: cp.ndarray, device: Device, /, stream: int | Any | None = None, -) -> _CupyArray: +) -> cp.ndarray: import cupy as cp if device == "cpu": @@ -819,7 +815,7 @@ def _torch_to_device( x: torch.Tensor, device: torch.device | str | int, /, - stream: None = None, + stream: int | Any | None = None, ) -> torch.Tensor: if stream is not None: raise NotImplementedError @@ -885,7 +881,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] + return _torch_to_device(x, device, stream=stream) elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") @@ -912,8 +908,6 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - @overload def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... @overload -def size(x: HasShape[Collection[None]]) -> None: ... -@overload def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: """ @@ -948,7 +942,7 @@ def _is_writeable_cls(cls: type) -> bool | None: return None -def is_writeable_array(x: object) -> bool: +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. @@ -986,7 +980,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return None -def is_lazy_array(x: object) -> bool: +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7ad87a1b..3fd9d860 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -8,7 +8,7 @@ if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot @@ -187,14 +187,14 @@ def vector_norm( # We can't reuse xp.linalg.norm(keepdims) because of the reshape hacks # above to avoid matrix norm logic. shape = list(x.shape) - _axis = cast( + axes = cast( "tuple[int, ...]", normalize_axis_tuple( # pyright: ignore[reportCallIssue] range(x.ndim) if axis is None else axis, x.ndim, ), ) - for i in _axis: + for i in axes: shape[i] = 1 res = xp.reshape(res, tuple(shape)) diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index cd26feeb..11b00bd1 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -34,32 +34,29 @@ # - docs: https://github.com/jorenham/optype/blob/master/README.md#just # - code: https://github.com/jorenham/optype/blob/master/optype/_core/_just.py @final -class JustInt(Protocol): - @property +class JustInt(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[int]: ... @__class__.setter def __class__(self, value: type[int], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustFloat(Protocol): - @property +class JustFloat(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[float]: ... @__class__.setter def __class__(self, value: type[float], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] @final -class JustComplex(Protocol): - @property +class JustComplex(Protocol): # type: ignore[misc] + @property # type: ignore[override] def __class__(self, /) -> type[complex]: ... @__class__.setter def __class__(self, value: type[complex], /) -> None: ... # pyright: ignore[reportIncompatibleMethodOverride] -# - - class NestedSequence(Protocol[_T_co]): def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 90b48f05..c0473ca4 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Optional +from builtins import bool as py_bool import cupy as cp @@ -67,18 +67,13 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = None, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = None, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -101,8 +96,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) @@ -113,8 +108,8 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, ) -> Array: result = cp.count_nonzero(x, axis) if keepdims: @@ -125,7 +120,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in cupy (and numpy) axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) @@ -153,4 +148,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'take_along_axis'] -_all_ignore = ['cp', 'get_xp'] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..2bd11940 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..7bc3536e 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -2,7 +2,7 @@ # cupy.linalg doesn't have __all__. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..6d2ea7cd 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -3,7 +3,7 @@ from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment] # noqa: F403 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index d43881ab..bc0302fe 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -146,7 +146,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 9e4d736f..2f39fc4b 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -12,9 +12,9 @@ from __future__ import annotations -from typing import Literal as L -from typing import TypeAlias, overload +from typing import Literal, TypeAlias, overload +import dask.array as da from numpy import bool_ as bool from numpy import ( complex64, @@ -33,7 +33,7 @@ uint64, ) -from ...common._helpers import _DASK_DEVICE, _dask_device +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device from ...common._typing import ( Capabilities, DefaultDTypes, @@ -49,8 +49,7 @@ DTypesSigned, DTypesUnsigned, ) - -_Device: TypeAlias = L["cpu"] | _dask_device +Device: TypeAlias = Literal["cpu"] | _dask_device class __array_namespace_info__: @@ -142,7 +141,7 @@ def capabilities(self) -> Capabilities: "max dimensions": 64, } - def default_device(self) -> L["cpu"]: + def default_device(self) -> Device: """ The default device used for new Dask arrays. @@ -169,7 +168,7 @@ def default_device(self) -> L["cpu"]: """ return "cpu" - def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: + def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -208,11 +207,7 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: 'indexing': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' - f"but received: {device!r}" - ) + _check_device(da, device) return { "real floating": dtype(float64), "complex floating": dtype(complex128), @@ -222,38 +217,38 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: @overload def dtypes( - self, /, *, device: _Device | None = None, kind: None = None + self, /, *, device: Device | None = None, kind: None = None ) -> DTypesAll: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["bool"] + self, /, *, device: Device | None = None, kind: Literal["bool"] ) -> DTypesBool: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["signed integer"] + self, /, *, device: Device | None = None, kind: Literal["signed integer"] ) -> DTypesSigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["unsigned integer"] + self, /, *, device: Device | None = None, kind: Literal["unsigned integer"] ) -> DTypesUnsigned: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["integral"] + self, /, *, device: Device | None = None, kind: Literal["integral"] ) -> DTypesIntegral: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["real floating"] + self, /, *, device: Device | None = None, kind: Literal["real floating"] ) -> DTypesReal: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["complex floating"] + self, /, *, device: Device | None = None, kind: Literal["complex floating"] ) -> DTypesComplex: ... @overload def dtypes( - self, /, *, device: _Device | None = None, kind: L["numeric"] + self, /, *, device: Device | None = None, kind: Literal["numeric"] ) -> DTypesNumeric: ... def dtypes( - self, /, *, device: _Device | None = None, kind: DTypeKind | None = None + self, /, *, device: Device | None = None, kind: DTypeKind | None = None ) -> DTypesAny: """ The array API data types supported by Dask. @@ -308,11 +303,7 @@ def dtypes( 'int64': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f" {device}" - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -381,14 +372,14 @@ def dtypes( "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + if isinstance(kind, tuple): res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[_Device]: + def devices(self) -> list[Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..68c4280e 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -2,7 +2,7 @@ # dask.array.fft doesn't have __all__. If it is added, replace this with # # from dask.array.fft import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.fft import *', _n) for k in ("__builtins__", "Sequence", "annotations", "warnings"): _n.pop(k, None) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..06f596bc 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,21 +4,22 @@ import dask.array as da -# The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, outer, tensordot - # Exports from dask.array.linalg import * # noqa: F403 +from dask.array import outer +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, tensordot + from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array as _Array +from ...common._typing import Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.linalg import *', _n) for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): _n.pop(k, None) @@ -33,8 +34,8 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr( - x: _Array, +def qr( # type: ignore[no-redef] + x: Array, mode: Literal["reduced", "complete"] = "reduced", **kwargs: object, ) -> QRResult: @@ -50,12 +51,12 @@ def qr( # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: _Array) -> _Array: +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 3e138f53..bf43fe61 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -10,7 +10,7 @@ from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index a1aee5c0..5a05a820 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,7 +2,7 @@ from __future__ import annotations from builtins import bool as py_bool -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast +from typing import Any, cast import numpy as np @@ -12,13 +12,6 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -if TYPE_CHECKING: - from typing_extensions import Buffer, TypeIs - -# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: -# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 -_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode - bool = np.bool_ # Basic renames @@ -74,14 +67,6 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] - try: - memoryview(obj) # pyright: ignore[reportArgumentType] - except TypeError: - return False - return True - - # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module @@ -92,7 +77,7 @@ def asarray( *, dtype: DType | None = None, device: Device | None = None, - copy: _Copy | None = None, + copy: py_bool | None = None, **kwargs: Any, ) -> Array: """ @@ -103,14 +88,14 @@ def asarray( """ _helpers._check_device(np, device) + # None is unsupported in NumPy 1.0, but we can use an internal enum + # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API if copy is None: - copy = np._CopyMode.IF_NEEDED + copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS + copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] - return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore + return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( @@ -141,7 +126,7 @@ def count_nonzero( # take_along_axis: axis defaults to -1 but in numpy axis is a required arg -def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): +def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return np.take_along_axis(x, indices, axis=axis) @@ -150,7 +135,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1): if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] if hasattr(np, "isdtype"): isdtype = np.isdtype diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index f307f62c..c625c13e 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -27,6 +27,7 @@ uint64, ) +from ..common._typing import DefaultDTypes from ._typing import Device, DType @@ -139,7 +140,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, dtype[intp | float64 | complex128]]: + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..9a618be9 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -65,7 +65,7 @@ # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, @@ -74,7 +74,7 @@ def solve(x1: Array, x2: Array, /) -> Array: isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index de5d1a5d..7a449001 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union, Literal +from typing import Any, Literal import torch @@ -96,9 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type( - *arrays_and_dtypes: Array | DType | bool | int | float | complex -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -129,10 +128,7 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type( - x: Array | DType | bool | int | float | complex, - y: Array | DType | bool | int | float | complex, -) -> DType: +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): xdt = x if isinstance(x, torch.dtype) else x.dtype ydt = y if isinstance(y, torch.dtype) else y.dtype @@ -150,7 +146,7 @@ def _result_type( return torch.result_type(x, y) -def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -194,12 +190,7 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -218,13 +209,13 @@ def asarray( # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -240,7 +231,15 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: +def sort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -307,10 +306,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: def prod(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -331,10 +330,10 @@ def prod(x: Array, def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -350,9 +349,9 @@ def sum(x: Array, def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -374,9 +373,9 @@ def any(x: Array, def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -398,9 +397,9 @@ def all(x: Array, def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -415,10 +414,10 @@ def mean(x: Array, def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -446,10 +445,10 @@ def std(x: Array, def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -472,11 +471,11 @@ def var(x: Array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[Array, ...], List[Array]], +def concat(arrays: tuple[Array, ...] | list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> Array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -485,7 +484,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -499,27 +498,27 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -532,8 +531,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) @@ -543,7 +542,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: result = torch.count_nonzero(x, dim=axis) @@ -564,12 +563,7 @@ def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Arr return torch.repeat_interleave(x, repeats, axis) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) @@ -577,10 +571,10 @@ def where( # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], *, - copy: Optional[bool] = None, - **kwargs) -> Array: + copy: bool | None = None, + **kwargs: object) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -589,14 +583,14 @@ def reshape(x: Array, # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], +def arange(start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -611,13 +605,13 @@ def arange(start: Union[int, float], # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -626,52 +620,52 @@ def eye(n_rows: int, return z # torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], +def linspace(start: float, + stop: float, /, num: int, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> Array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k @@ -693,14 +687,14 @@ def astype( /, *, copy: bool = True, - device: Optional[Device] = None, + device: Device | None = None, ) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -738,7 +732,7 @@ def unique_inverse(x: Array) -> UniqueInverseResult: def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -756,8 +750,8 @@ def tensordot( x2: Array, /, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). @@ -766,8 +760,10 @@ def tensordot( def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -801,7 +797,7 @@ def isdtype( else: return dtype == kind -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") @@ -828,7 +824,7 @@ def sign(x: Array, /) -> Array: return out -def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> List[Array]: +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: # enforce the default of 'xy' # TODO: is the return type a list or a tuple return list(torch.meshgrid(*arrays, indexing='xy')) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..ddf87c65 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal import torch import torch.fft @@ -17,7 +18,7 @@ def fftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -28,7 +29,7 @@ def ifftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -39,7 +40,7 @@ def rfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -50,7 +51,7 @@ def irfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -58,8 +59,8 @@ def fftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) @@ -67,8 +68,8 @@ def ifftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 70d72405..558cfe7b 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,8 +1,6 @@ from __future__ import annotations import torch -from typing import Optional, Union, Tuple - from torch.linalg import * # noqa: F403 # torch.linalg doesn't define __all__ @@ -32,7 +30,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -54,7 +52,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: Array, x2: Array, /, **kwargs) -> Array: +def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -75,7 +73,7 @@ def solve(x1: Array, x2: Array, /, **kwargs) -> Array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) @@ -83,11 +81,11 @@ def vector_norm( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, # JustFloat stands for inf | -inf, which are not valid for Literal ord: JustInt | JustFloat = 2, - **kwargs, + **kwargs: object, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): diff --git a/pyproject.toml b/pyproject.toml index aacebd11..ec054417 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,11 @@ dev = [ "array-api-strict", "dask[array]>=2024.9.0", "jax[cpu]", + "ndonnx", "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx" ] [project.urls] @@ -61,7 +61,7 @@ version = {attr = "array_api_compat.__version__"} include = ["array_api_compat*"] namespaces = false -[toolint] +[tool.ruff.lint] preview = true select = [ # Defaults @@ -79,20 +79,42 @@ ignore = [ "E722" ] -[tool.ruff.lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" +[tool.mypy] +files = ["array_api_compat"] +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = false +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "cupy_backends.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, ] From cddc9ef8a19b453b09884987ca6a0626408a1478 Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Fri, 6 Jun 2025 12:04:01 +0100 Subject: [PATCH 076/151] ENH: Review exported symbols; redesign `test_all` (#315) Review and discussion at https://github.com/data-apis/array-api-compat/pull/315 --- array_api_compat/_internal.py | 20 +- array_api_compat/common/_aliases.py | 2 - array_api_compat/common/_helpers.py | 2 - array_api_compat/common/_linalg.py | 2 - array_api_compat/cupy/__init__.py | 13 +- array_api_compat/cupy/_aliases.py | 3 +- array_api_compat/cupy/_typing.py | 1 - array_api_compat/cupy/fft.py | 7 +- array_api_compat/cupy/linalg.py | 6 +- array_api_compat/dask/array/__init__.py | 16 +- array_api_compat/dask/array/_aliases.py | 4 - array_api_compat/dask/array/fft.py | 19 +- array_api_compat/dask/array/linalg.py | 35 +-- array_api_compat/numpy/__init__.py | 22 +- array_api_compat/numpy/_aliases.py | 6 +- array_api_compat/numpy/_typing.py | 1 - array_api_compat/numpy/fft.py | 15 +- array_api_compat/numpy/linalg.py | 29 +- array_api_compat/torch/__init__.py | 29 +- array_api_compat/torch/_aliases.py | 5 +- array_api_compat/torch/fft.py | 16 +- array_api_compat/torch/linalg.py | 19 +- tests/test_all.py | 360 ++++++++++++++++++++---- 23 files changed, 435 insertions(+), 197 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index b1925492..baa39ded 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,6 +2,7 @@ Internal helpers """ +import importlib from collections.abc import Callable from functools import wraps from inspect import signature @@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return inner -__all__ = ["get_xp"] +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) + +__all__ = ["get_xp", "clone_module"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 51732b91..27b2604b 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -721,8 +721,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index cae0ee0b..37f31ec2 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1062,7 +1062,5 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: "to_device", ] -_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 3fd9d860..69672af7 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -225,8 +225,6 @@ def trace( 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] -_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 9a30f95d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,9 +6,19 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index c0473ca4..2752bd98 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -7,7 +7,6 @@ from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = cp.bool_ @@ -141,7 +140,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index d8e49ca7..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,7 +1,6 @@ from __future__ import annotations __all__ = ["Array", "DType", "Device"] -_all_ignore = ["cp"] from typing import TYPE_CHECKING diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 2bd11940..53a9a454 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -31,7 +31,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7bc3536e..da301574 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 6d2ea7cd..f78aa8b3 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,26 @@ from typing import Final -from dask.array import * # noqa: F403 +from ..._internal import clone_module + +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # type: ignore[assignment] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" +del Final # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index bc0302fe..4d1e7341 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -41,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -355,7 +354,6 @@ def count_nonzero( __all__ = [ - "__array_namespace_info__", "count_nonzero", "bool", "int8", "int16", "int32", "int64", @@ -369,8 +367,6 @@ def count_nonzero( "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", ] # fmt: skip __all__ += _aliases.__all__ -_all_ignore = ["array_namespace", "get_xp", "da", "np"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 68c4280e..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,13 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__. If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.fft import *', _n) -for k in ("__builtins__", "Sequence", "annotations", "warnings"): - _n.pop(k, None) -fft_all = list(_n) -del _n, k +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -17,5 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = fft_all + ["fftfreq", "rfftfreq"] -_all_ignore = ["da", "fft_all", "get_xp", "warnings"] +__all__ += ["fftfreq", "rfftfreq"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 06f596bc..6b3c1011 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,27 +4,17 @@ import dask.array as da -# Exports -from dask.array.linalg import * # noqa: F403 -from dask.array import outer # The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, tensordot - +from dask.array import matmul, outer, tensordot -from ..._internal import get_xp +# Exports +from ..._internal import clone_module, get_xp from ...common import _linalg from ...common._typing import Array -from ._aliases import matrix_transpose, vecdot -# dask.array.linalg doesn't have __all__. If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n: dict[str, object] = {} -exec('from dask.array.linalg import *', _n) -for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): - _n.pop(k, None) -linalg_all = list(_n) -del _n, k +__all__ = clone_module("dask.array.linalg", globals()) + +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -64,10 +54,11 @@ def svdvals(x: Array) -> Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index bf43fe61..23379e44 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,16 +1,17 @@ # ruff: noqa: PLC0414 from typing import Final -from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] +from .._internal import clone_module -# from numpy import * doesn't overwrite these builtin names -from numpy import abs as abs -from numpy import max as max -from numpy import min as min -from numpy import round as round +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -26,3 +27,12 @@ from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 5a05a820..5bb8869a 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = np.bool_ @@ -147,8 +146,7 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: else: unstack = get_xp(np)(_aliases.unstack) -__all__ = [ - "__array_namespace_info__", +__all__ = _aliases.__all__ + [ "asarray", "astype", "acos", @@ -167,8 +165,6 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: "pow", "take_along_axis" ] -__all__ += _aliases.__all__ -_all_ignore = ["np", "get_xp"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -23,7 +23,6 @@ Array: TypeAlias = np.ndarray __all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 06875f00..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,6 +1,8 @@ import numpy as np -from numpy.fft import __all__ as fft_all -from numpy.fft import fft2, ifft2, irfft2, rfft2 + +from .._internal import clone_module + +__all__ = clone_module("numpy.fft", globals()) from .._internal import get_xp from ..common import _fft @@ -21,15 +23,8 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] -__all__ += _fft.__all__ - +__all__ = sorted(set(__all__) | set(_fft.__all__)) def __dir__() -> list[str]: return __all__ - -del get_xp -del np -del fft_all -del _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 9a618be9..7168441c 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,26 +7,11 @@ import numpy as np -# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` -from numpy.linalg import ( - LinAlgError, - cond, - det, - eig, - eigvals, - eigvalsh, - inv, - lstsq, - matrix_power, - multi_dot, - norm, - tensorinv, - tensorsolve, -) - -from .._internal import get_xp +from .._internal import clone_module, get_xp from ..common import _linalg +__all__ = clone_module("numpy.linalg", globals()) + # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array @@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = [ +_all = [ "LinAlgError", "cond", "det", @@ -132,12 +117,12 @@ def solve(x1: Array, x2: Array, /) -> Array: "matrix_power", "multi_dot", "norm", + "solve", "tensorinv", "tensorsolve", + "vector_norm", ] -__all__ += _linalg.__all__ -__all__ += ["solve", "vector_norm"] - +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 69fd19ce..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,22 +1,25 @@ -from torch import * # noqa: F403 +from typing import Final -# Several names are not included in the above import * -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'cuda' in n - or 'cpu' in n - or 'backward' in n): - continue - exec(f"{n} = torch.{n}") -del n +from .._internal import clone_module + +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7a449001..91161d24 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -10,7 +10,6 @@ from .._internal import get_xp from ..common import _aliases from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType _int_dtypes = { @@ -830,7 +829,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array return list(torch.meshgrid(*arrays, indexing='xy')) -__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', @@ -847,5 +846,3 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat', 'meshgrid'] - -_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index ddf87c65..76342980 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -5,9 +5,11 @@ import torch import torch.fft -from torch.fft import * # noqa: F403 from ._typing import Array +from .._internal import clone_module + +__all__ = clone_module("torch.fft", globals()) # Several torch fft functions do not map axes to dim @@ -74,13 +76,7 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -] +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] -_all_ignore = ['torch'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 558cfe7b..08271d22 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,12 +1,11 @@ from __future__ import annotations import torch -from torch.linalg import * # noqa: F403 +import torch.linalg -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +from .._internal import clone_module + +__all__ = clone_module("torch.linalg", globals()) # outer is implemented in torch but aren't in the linalg namespace from torch import outer @@ -28,7 +27,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) - return torch_linalg.cross(x1, x2, dim=axis) + return torch.linalg.cross(x1, x2, dim=axis) def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype @@ -108,12 +107,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] - -_all_ignore = ['torch_linalg', 'sum'] - -del linalg_all +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 271cd189..c36aef67 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,63 +1,311 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. -""" +from array_api_compat._internal import clone_module +from ._helpers import wrapped_libraries -import sys +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Submodules + "fft", + "linalg", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation Functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set Functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -from ._helpers import import_, wrapped_libraries +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -import pytest -import typing - -TYPING_NAMES = frozenset(( - "Array", - "Device", - "DType", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -)) - -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == "common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - # NB: iterate over a copy to avoid a "dictionary size changed" error - for mod_name in sys.modules.copy(): - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. - if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = set(getattr(module, '_all_ignore', ())) - ignore_all_names |= set(dir(typing)) - ignore_all_names |= {"annotations"} - if not module.__name__.endswith("._typing"): - ignore_all_names |= TYPING_NAMES - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" + +def all_names(mod): + """Return all names available in a module.""" + objs = {} + clone_module(mod.__name__, objs) + return set(objs) + + +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports + dictated by the Standard. + """ + mod = get_mod(library, module, compat=True) + missing = set(NAMES[module]) - all_names(mod) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + missing = all_names(bare_mod) - all_names(compat_mod) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_add_names(library, module): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. + """ + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) + + aapi_names = set(NAMES[module]) + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names + # Quietly ignore *Result dataclasses + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" + ) + + +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name) From 1d1178d33f7af737abf697a76fb161901faa075d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Jun 2025 10:16:28 +0000 Subject: [PATCH 077/151] Bump dawidd6/action-download-artifact from 10 to 11 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 10 to 11 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v10...v11) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '11' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 4e3efb39..ed90b29d 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Download Artifact - uses: dawidd6/action-download-artifact@v10 + uses: dawidd6/action-download-artifact@v11 with: workflow: docs-build.yml name: docs-build From d9c3646bfc53dfd37eb921f6af8fc533d029d9d3 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 29 Jun 2025 10:39:30 +0200 Subject: [PATCH 078/151] BUG: torch/meshgrid: stop ignoring the "indexing" argument --- array_api_compat/torch/_aliases.py | 2 +- tests/test_torch.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..40960f45 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -826,7 +826,7 @@ def sign(x: Array, /) -> Array: def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: # enforce the default of 'xy' # TODO: is the return type a list or a tuple - return list(torch.meshgrid(*arrays, indexing='xy')) + return list(torch.meshgrid(*arrays, indexing=indexing)) __all__ = ['asarray', 'result_type', 'can_cast', diff --git a/tests/test_torch.py b/tests/test_torch.py index 7adb4ab3..f661a272 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -117,3 +117,16 @@ def test_meshgrid(): assert Y.shape == Y_xy.shape assert xp.all(Y == Y_xy) + + # repeat with an explicit indexing + X, Y = xp.meshgrid(x, y, indexing='ij') + + # output of torch.meshgrid(x, y, indexing='ij') + X_ij, Y_ij = xp.asarray([[1], [2]]), xp.asarray([[4], [4]]) + + assert X.shape == X_ij.shape + assert xp.all(X == X_ij) + + assert Y.shape == Y_ij.shape + assert xp.all(Y == Y_ij) + From fa35e90a1b56a8338372373064af67c47db945f5 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 29 Jun 2025 11:54:16 +0200 Subject: [PATCH 079/151] CI: some dask tests require numpy >= 3.12 --- tests/test_dask.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_dask.py b/tests/test_dask.py index fb0a84d4..4200e5b7 100644 --- a/tests/test_dask.py +++ b/tests/test_dask.py @@ -1,3 +1,4 @@ +import sys from contextlib import contextmanager import numpy as np @@ -167,6 +168,10 @@ def test_sort_argsort_chunk_size(xp, func, shape, chunks): ) +@pytest.mark.skipif( + sys.version_info.major*100 + sys.version_info.minor < 312, + reason="dask interop requires numpy >= 3.12" +) @pytest.mark.parametrize("func", ["sort", "argsort"]) def test_sort_argsort_meta(xp, func): """Test meta-namespace other than numpy""" From 4bafa4cc8a455a301f3688fd3fa7404a4fe00974 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:10:23 +0000 Subject: [PATCH 080/151] Bump actions/download-artifact from 4 to 5 in the actions group Bumps the actions group with 1 update: [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/download-artifact` from 4 to 5 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/download-artifact dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/publish-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 6d88066d..1e28689c 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -81,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: dist-artifact path: dist From edd9072c296827d0e4eccf02ae87920eb2481b9c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:07:35 +0000 Subject: [PATCH 081/151] Bump actions/checkout from 4 to 5 in the actions group Bumps the actions group with 1 update: [actions/checkout](https://github.com/actions/checkout). Updates `actions/checkout` from 4 to 5 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v4...v5) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/array-api-tests.yml | 4 ++-- .github/workflows/docs-build.yml | 2 +- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/tests.yml | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index e832f870..5c3cc7d9 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -49,12 +49,12 @@ jobs: steps: - name: Checkout array-api-compat - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: path: array-api-compat - name: Checkout array-api-tests - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: repository: data-apis/array-api-tests submodules: 'true' diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 34b9cbc6..778d20e2 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -6,7 +6,7 @@ jobs: docs-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-python@v5 - name: Install Dependencies run: | diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index ed90b29d..42a3598f 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -11,7 +11,7 @@ jobs: environment: name: docs-deploy steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Download Artifact uses: dawidd6/action-download-artifact@v11 with: diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 1e28689c..03cae174 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 with: fetch-depth: 0 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index a9f0fd4b..68f68a14 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -5,7 +5,7 @@ jobs: runs-on: ubuntu-latest continue-on-error: true steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Install Python uses: actions/setup-python@v5 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c995b370..d2e768eb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: python-version: '3.13' steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} From c2b7a51c85d037fba4ea7dea7d0efe74a13bb550 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 8 Sep 2025 09:12:57 +0000 Subject: [PATCH 082/151] Bump the actions group with 2 updates Bumps the actions group with 2 updates: [actions/setup-python](https://github.com/actions/setup-python) and [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish). Updates `actions/setup-python` from 5 to 6 - [Release notes](https://github.com/actions/setup-python/releases) - [Commits](https://github.com/actions/setup-python/compare/v5...v6) Updates `pypa/gh-action-pypi-publish` from 1.12.4 to 1.13.0 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.12.4...v1.13.0) --- updated-dependencies: - dependency-name: actions/setup-python dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: pypa/gh-action-pypi-publish dependency-version: 1.13.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/array-api-tests.yml | 2 +- .github/workflows/docs-build.yml | 2 +- .github/workflows/publish-package.yml | 6 +++--- .github/workflows/ruff.yml | 2 +- .github/workflows/tests.yml | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 5c3cc7d9..e3c0c9e0 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -61,7 +61,7 @@ jobs: path: array-api-tests - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 778d20e2..305a9003 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v5 - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 - name: Install Dependencies run: | python -m pip install .[docs] diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 03cae174..bbfb2e80 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -35,7 +35,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' @@ -95,14 +95,14 @@ jobs: # if: >- # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.12.4 + # uses: pypa/gh-action-pypi-publish@v1.13.0 # with: # repository-url: https://test.pypi.org/legacy/ # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.12.4 + uses: pypa/gh-action-pypi-publish@v1.13.0 with: print-hash: true diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 68f68a14..4a2ffcff 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -7,7 +7,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Install Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.11" - name: Install dependencies diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d2e768eb..cfbb875f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -24,7 +24,7 @@ jobs: steps: - uses: actions/checkout@v5 - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install Dependencies From d794015c6f919267e8274a619f82b64380cd5a5a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 8 Sep 2025 19:36:49 +0200 Subject: [PATCH 083/151] ENH: use torch.clamp for wrapped_torch.clip Otherwise, the version which emulates "clip" fails with torch.vmap (see gh-350) --- array_api_compat/torch/_aliases.py | 33 +++++++++++++++++++++++++++++- tests/test_torch.py | 12 +++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..af3dffc5 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -220,7 +220,6 @@ def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool return torch.clone(x) return torch.amin(x, axis, keepdims=keepdims) -clip = get_xp(torch)(_aliases.clip) unstack = get_xp(torch)(_aliases.unstack) cumulative_sum = get_xp(torch)(_aliases.cumulative_sum) cumulative_prod = get_xp(torch)(_aliases.cumulative_prod) @@ -808,6 +807,38 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return torch.take_along_dim(x, indices, dim=axis) +def clip( + x: Array, + /, + min: int | float | Array | None = None, + max: int | float | Array | None = None, + **kwargs +) -> Array: + def _isscalar(a: object): + return isinstance(a, int | float) or a is None + + # cf clip in common/_aliases.py + if not x.is_floating_point(): + if type(min) is int and min <= torch.iinfo(x.dtype).min: + min = None + if type(max) is int and max >= torch.iinfo(x.dtype).max: + max = None + + if min is None and max is None: + return torch.clone(x) + + min_is_scalar = _isscalar(min) + max_is_scalar = _isscalar(max) + + if min is not None and max is not None: + if min_is_scalar and not max_is_scalar: + min = torch.as_tensor(min, dtype=x.dtype, device=x.device) + if max_is_scalar and not min_is_scalar: + max = torch.as_tensor(max, dtype=x.dtype, device=x.device) + + return torch.clamp(x, min, max, **kwargs) + + def sign(x: Array, /) -> Array: # torch sign() does not support complex numbers and does not propagate # nans. See https://github.com/data-apis/array-api-compat/issues/136 diff --git a/tests/test_torch.py b/tests/test_torch.py index 7adb4ab3..b3445a0e 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -102,6 +102,18 @@ def test_gh_273(self, default_dt, dtype_a, dtype_b): torch.set_default_dtype(prev_default) +def test_clip_vmap(): + # https://github.com/data-apis/array-api-compat/issues/350 + def apply_clip_compat(a): + return xp.clip(a, min=0, max=30) + + a = xp.asarray([[5.1, 2.0, 64.1, -1.5]]) + + ref = apply_clip_compat(a) + v1 = torch.vmap(apply_clip_compat) + assert xp.all(v1(a) == ref) + + def test_meshgrid(): """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" From 85cf2285ac56230cf2c79c30ffc8a4727a057dc6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 19 Oct 2025 19:48:09 +0200 Subject: [PATCH 084/151] FIX: Wrap torch.argsort to set stable=True by default --- array_api_compat/torch/_aliases.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 91161d24..7810e057 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -241,6 +241,21 @@ def sort( ) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values + +# Wrap torch.argsort to set stable=True by default +def argsort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: + + return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) + + def _normalize_axes(axis, ndim): axes = [] if ndim == 0 and axis: From 31e65041f2605040a91f9d56ff6e7ef8fb671daf Mon Sep 17 00:00:00 2001 From: Arthur Lacote Date: Sun, 19 Oct 2025 19:52:16 +0200 Subject: [PATCH 085/151] Apply suggestion from @Copilot Remove the empty line with trailing whitespace inside the function body. This line serves no purpose and should be deleted. Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- array_api_compat/torch/_aliases.py | 1 - 1 file changed, 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7810e057..715182a1 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -252,7 +252,6 @@ def argsort( stable: bool = True, **kwargs: object, ) -> Array: - return torch.argsort(x, dim=axis, descending=descending, stable=stable, **kwargs) From 1233b7bf65f382fa5153e8c9190692e67cdd89cc Mon Sep 17 00:00:00 2001 From: Arthur Date: Sun, 19 Oct 2025 19:58:18 +0200 Subject: [PATCH 086/151] fix linting --- array_api_compat/torch/fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 76342980..f11b3eb5 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Literal -import torch +import torch # noqa: F401 import torch.fft from ._typing import Array From 65412edbe1a6aa32913ce4cf2fb7f29552937b70 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Oct 2025 09:31:56 +0000 Subject: [PATCH 087/151] Bump the actions group with 2 updates Bumps the actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/upload-artifact` from 4 to 5 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v4...v5) Updates `actions/download-artifact` from 5 to 6 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '5' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: actions/download-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-build.yml | 2 +- .github/workflows/publish-package.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 305a9003..1fd6f9d5 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -16,7 +16,7 @@ jobs: cd docs make html - name: Upload Artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: docs-build path: docs/_build/html diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index bbfb2e80..485295c1 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -60,7 +60,7 @@ jobs: run: python -m zipfile --list dist/array_api_compat-*.whl - name: Upload distribution artifact - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v5 with: name: dist-artifact path: dist @@ -81,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v5 + uses: actions/download-artifact@v6 with: name: dist-artifact path: dist From 1fafddae1633f1141483c21877beff8f9e9729b5 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 28 Oct 2025 13:47:24 +0100 Subject: [PATCH 088/151] added to aliases --- array_api_compat/torch/_aliases.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 715182a1..23dafde9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -851,9 +851,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', - 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', 'prod', 'sum', - 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze', - 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', + 'min', 'clip', 'unstack', 'cumulative_sum', 'cumulative_prod', 'sort', + 'argsort', 'prod', 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', + 'squeeze', 'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', From 7cf0b798352170d260a1b4a9623a770daa36c703 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 28 Oct 2025 17:45:29 +0100 Subject: [PATCH 089/151] TST: add a test for torch.argsort defaulting to stable=True cross-ref https://github.com/data-apis/array-api-compat/pull/356 which wrapped torch.argsort to fix the default, and https://github.com/data-apis/array-api-tests/pull/390 which made a matching change in the array-api-test suite. --- tests/test_torch.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_torch.py b/tests/test_torch.py index f661a272..c8619565 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -130,3 +130,13 @@ def test_meshgrid(): assert Y.shape == Y_ij.shape assert xp.all(Y == Y_ij) + +def test_argsort_stable(): + """Verify that argsort defaults to a stable sort.""" + # Bare pytorch defaults to an unstable sort, and the array_api_compat wrapper + # enforces the stable=True default. + # cf https://github.com/data-apis/array-api-compat/pull/356 and + # https://github.com/data-apis/array-api-tests/pull/390#issuecomment-3452868329 + + t = xp.zeros(50) # should be >16 + assert xp.all(xp.argsort(t) == xp.arange(50)) From a60869af172d8575ae31a8bff6653b7de26bec79 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 16:33:12 +0100 Subject: [PATCH 090/151] ENH: torch: allow negative indices in take_along_axis --- array_api_compat/torch/_aliases.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index d3857755..2903ac3e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -819,7 +819,11 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: - return torch.take_along_dim(x, indices, dim=axis) + return torch.take_along_dim( + x, + torch.where(indices < 0, indices + x.shape[axis], indices), + dim=axis + ) def sign(x: Array, /) -> Array: From 92662a6f41e6c7c5d168a5c231e66c7a63c56674 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 22 Nov 2025 17:23:05 +0100 Subject: [PATCH 091/151] ENH: torch: allow negative indices in take() --- array_api_compat/torch/_aliases.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 2903ac3e..7fc1194e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -815,7 +815,12 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 - return torch.index_select(x, axis, indices, **kwargs) + return torch.index_select( + x, + axis, + torch.where(indices < 0, indices + x.shape[axis], indices), + **kwargs + ) def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: From 4355ab819c3c24c15861e3891bcdd58899f12421 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 23 Nov 2025 10:35:35 +0100 Subject: [PATCH 092/151] MAINT: link to pytorch issue for negative indices --- array_api_compat/torch/_aliases.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 7fc1194e..4e8533f9 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -815,6 +815,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") axis = 0 + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 return torch.index_select( x, axis, @@ -824,6 +826,8 @@ def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: obje def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: + # torch does not support negative indices, + # see https://github.com/pytorch/pytorch/issues/146211 return torch.take_along_dim( x, torch.where(indices < 0, indices + x.shape[axis], indices), From 0d559ce7144f61e04d029ebaedceb9a3c2d21fa2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 24 Nov 2025 09:10:38 +0000 Subject: [PATCH 093/151] Bump actions/checkout from 5 to 6 in the actions group Bumps the actions group with 1 update: [actions/checkout](https://github.com/actions/checkout). Updates `actions/checkout` from 5 to 6 - [Release notes](https://github.com/actions/checkout/releases) - [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md) - [Commits](https://github.com/actions/checkout/compare/v5...v6) --- updated-dependencies: - dependency-name: actions/checkout dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/array-api-tests.yml | 4 ++-- .github/workflows/docs-build.yml | 2 +- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 2 +- .github/workflows/ruff.yml | 2 +- .github/workflows/tests.yml | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index e3c0c9e0..8d78225c 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -49,12 +49,12 @@ jobs: steps: - name: Checkout array-api-compat - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: path: array-api-compat - name: Checkout array-api-tests - uses: actions/checkout@v5 + uses: actions/checkout@v6 with: repository: data-apis/array-api-tests submodules: 'true' diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 1fd6f9d5..013e69db 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -6,7 +6,7 @@ jobs: docs-build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6 - name: Install Dependencies run: | diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 42a3598f..fbd7c89a 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -11,7 +11,7 @@ jobs: environment: name: docs-deploy steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Download Artifact uses: dawidd6/action-download-artifact@v11 with: diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 485295c1..8710965c 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -30,7 +30,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 with: fetch-depth: 0 diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 4a2ffcff..6e838902 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -5,7 +5,7 @@ jobs: runs-on: ubuntu-latest continue-on-error: true steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - name: Install Python uses: actions/setup-python@v6 with: diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index cfbb875f..585304b1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -23,7 +23,7 @@ jobs: python-version: '3.13' steps: - - uses: actions/checkout@v5 + - uses: actions/checkout@v6 - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} From 609d0a076b798ea9027d8674de2097d512ffeb85 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 30 Nov 2025 19:01:12 +0100 Subject: [PATCH 094/151] CI: add python 3.14 to the CI matrix --- .github/workflows/array-api-tests-numpy-dev.yml | 2 +- .github/workflows/array-api-tests-numpy-latest.yml | 2 +- .github/workflows/array-api-tests-torch.yml | 2 +- .github/workflows/tests.yml | 6 +++++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/.github/workflows/array-api-tests-numpy-dev.yml b/.github/workflows/array-api-tests-numpy-dev.yml index dec4c7ae..7a521360 100644 --- a/.github/workflows/array-api-tests-numpy-dev.yml +++ b/.github/workflows/array-api-tests-numpy-dev.yml @@ -9,7 +9,7 @@ jobs: package-name: numpy extra-requires: '--pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple' xfails-file-extra: '-dev' - python-versions: '[''3.11'', ''3.13'']' + python-versions: '[''3.11'', ''3.13'', ''3.14'']' pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-numpy-latest.yml b/.github/workflows/array-api-tests-numpy-latest.yml index 65bbc9a2..03e0e11e 100644 --- a/.github/workflows/array-api-tests-numpy-latest.yml +++ b/.github/workflows/array-api-tests-numpy-latest.yml @@ -7,7 +7,7 @@ jobs: uses: ./.github/workflows/array-api-tests.yml with: package-name: numpy - python-versions: '[''3.10'', ''3.13'']' + python-versions: '[''3.10'', ''3.13'', ''3.14'']' pytest-extra-args: -n 4 extra-env-vars: | ARRAY_API_TESTS_XFAIL_MARK=skip diff --git a/.github/workflows/array-api-tests-torch.yml b/.github/workflows/array-api-tests-torch.yml index 4b4b945e..d5cdfa72 100644 --- a/.github/workflows/array-api-tests-torch.yml +++ b/.github/workflows/array-api-tests-torch.yml @@ -11,5 +11,5 @@ jobs: extra-env-vars: | ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 ARRAY_API_TESTS_XFAIL_MARK=skip - python-versions: '[''3.10'', ''3.13'']' + python-versions: '[''3.10'', ''3.13'', ''3.14'']' pytest-extra-args: -n 4 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 585304b1..10894c2a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,11 +17,15 @@ jobs: python-version: '3.10' - numpy-version: 'latest' python-version: '3.13' + - numpy-version: 'latest' + python-version: '3.14' - numpy-version: 'dev' python-version: '3.11' - numpy-version: 'dev' python-version: '3.13' - + - numpy-version: 'dev' + python-version: '3.14' + steps: - uses: actions/checkout@v6 - uses: actions/setup-python@v6 From 1ef5b07c5e636962c2c757fbd61a9c9e7a3ae627 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 30 Nov 2025 20:17:31 +0100 Subject: [PATCH 095/151] CI: skip install sparse and ndonnx on py 3.14 --- .github/workflows/tests.yml | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 10894c2a..c27283ed 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -42,14 +42,20 @@ jobs: if [ "${{ matrix.numpy-version }}" == "dev" ]; then python -m pip install numpy --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple - python -m pip install dask[array] jax[cpu] sparse ndonnx + python -m pip install dask[array] jax[cpu] + if ["${{ matrix.python-version }}" != "3.14]; then + python -m pip install sparse ndonnx + fi elif [ "${{ matrix.numpy-version }}" == "1.22" ]; then python -m pip install 'numpy==1.22.*' elif [ "${{ matrix.numpy-version }}" == "1.26" ]; then python -m pip install 'numpy==1.26.*' else python -m pip install numpy - python -m pip install dask[array] jax[cpu] sparse ndonnx + python -m pip install dask[array] jax[cpu] + if ["${{ matrix.python-version }}" != "3.14]; then + python -m pip install sparse ndonnx + fi fi - name: Dump pip environment From da7d9ecf9e0d910ff259ef74b05f583af9e43b6f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 15 Dec 2025 09:06:08 +0000 Subject: [PATCH 096/151] Bump the actions group with 2 updates Bumps the actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/upload-artifact` from 5 to 6 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v5...v6) Updates `actions/download-artifact` from 6 to 7 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v6...v7) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '6' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: actions/download-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-build.yml | 2 +- .github/workflows/publish-package.yml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 013e69db..91a68457 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -16,7 +16,7 @@ jobs: cd docs make html - name: Upload Artifact - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: docs-build path: docs/_build/html diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 8710965c..ce47aed5 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -60,7 +60,7 @@ jobs: run: python -m zipfile --list dist/array_api_compat-*.whl - name: Upload distribution artifact - uses: actions/upload-artifact@v5 + uses: actions/upload-artifact@v6 with: name: dist-artifact path: dist @@ -81,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v6 + uses: actions/download-artifact@v7 with: name: dist-artifact path: dist From b61e9c3fbc55e1fb66a63b4d4f333fb04dbd3879 Mon Sep 17 00:00:00 2001 From: Martin Schuck <57562633+amacati@users.noreply.github.com> Date: Sat, 27 Dec 2025 19:15:40 +0100 Subject: [PATCH 097/151] BUG: Fix `is_jax_array` for `jax>=0.8.2` (#369) * Fix is_jax_array for jax>=0.8.2 * Skip jax test if not installed * Fix and test array_api_obj, is_writable_array, is_lazy_array * Add comments on jax.core.Tracer detection limitations --- array_api_compat/common/_helpers.py | 15 ++++++++++++++- tests/test_jax.py | 25 ++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 37f31ec2..8194a083 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,7 +235,17 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_pydata_sparse_array """ cls = cast(Hashable, type(x)) - return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) + # We test for jax.core.Tracer here to identify jax arrays during jit tracing. From jax 0.8.2 on, + # tracers are not a subclass of jax.Array anymore. Note that tracers can also represent + # non-array values and a fully correct implementation would need to use isinstance checks. Since + # we use hash-based caching with type names as keys, we cannot use instance checks without + # losing performance here. For more information, see + # https://github.com/data-apis/array-api-compat/pull/369 and the corresponding issue. + return ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") + or _is_jax_zero_gradient_array(x) + ) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -296,6 +306,7 @@ def _is_array_api_cls(cls: type) -> bool: or _issubclass_fast(cls, "sparse", "SparseArray") # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations ) @@ -934,6 +945,7 @@ def _is_writeable_cls(cls: type) -> bool | None: if ( _issubclass_fast(cls, "numpy", "generic") or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "sparse", "SparseArray") ): return False @@ -973,6 +985,7 @@ def _is_lazy_cls(cls: type) -> bool | None: return False if ( _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "jax.core", "Tracer") # see is_jax_array for limitations or _issubclass_fast(cls, "dask.array", "Array") or _issubclass_fast(cls, "ndonnx", "Array") ): diff --git a/tests/test_jax.py b/tests/test_jax.py index 285958d4..322d0223 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,7 +1,14 @@ from numpy.testing import assert_equal import pytest -from array_api_compat import device, to_device +from array_api_compat import ( + device, + to_device, + is_jax_array, + is_lazy_array, + is_array_api_obj, + is_writeable_array, +) try: import jax @@ -13,7 +20,7 @@ @pytest.mark.parametrize( - "func", + "func", [ lambda x: jnp.zeros(1, device=device(x)), lambda x: jnp.zeros_like(jnp.ones(1, device=device(x))), @@ -26,7 +33,7 @@ ), ), lambda x: to_device(jnp.zeros(1), device(x)), - ] + ], ) def test_device_jit(func): # Test work around to https://github.com/jax-ml/jax/issues/26000 @@ -36,3 +43,15 @@ def test_device_jit(func): x = jnp.ones(1) assert_equal(func(x), jnp.asarray([0])) assert_equal(jax.jit(func)(x), jnp.asarray([0])) + + +def test_inside_jit(): + # Test if jax arrays are handled correctly inside jax.jit. + # Jax tracers are not a subclass of jax.Array from 0.8.2 on. We explicitly test that + # tracers are handled appropriately. For limitations, see is_jax_array() docstring. + # Reference issue: https://github.com/data-apis/array-api-compat/issues/368 + x = jnp.asarray([1, 2, 3]) + assert jax.jit(is_jax_array)(x) + assert jax.jit(is_array_api_obj)(x) + assert not jax.jit(is_writeable_array)(x) + assert jax.jit(is_lazy_array)(x) From 946ce4ad77968b94e93594c79653162426ec3224 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Sun, 28 Dec 2025 11:22:23 +0000 Subject: [PATCH 098/151] REL: prepare 1.13.0 (#370) --- array_api_compat/__init__.py | 2 +- docs/changelog.md | 29 +++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index a00e8cbc..4abca400 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.13.0.dev0' +__version__ = '1.13.0' from .common import * # noqa: F401, F403 diff --git a/docs/changelog.md b/docs/changelog.md index 6f6c1251..fe07c9e0 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,34 @@ # Changelog +## 1.13.0 (2025-12-28) + + +### Major changes + +- Support for Python 3.14 has been added. +- Symbols exported in public namespaces have been reviewed and adjusted. +- `torch.take` and `torch.take_along_axis` now support negative indices. +- `torch.meshgrid` now correctly processes the `indexing` argument. +- View/copy semantics are now observed for the `ceil`, `floor`, and `trunc` functions. + +### Minor changes + +- `array_namespace` has been sped up via caching. +- The `stable` parameter of `torch.argsort` now defaults to `True`, per the standard. +- Type annotations have seen progress. +- `is_jax_array` has been adjusted for compatibility with `jax>=0.8.2` + + +The following users contributed to this release: + +Evgeni Burovski, +Guido Imperiale, +Lucas Colley, +Arthur Lacote, +Martin Schuck, +Matt Haberland. + + ## 1.12.0 (2025-05-13) From 3f927f11fe03ca0a56634db57056830351b67535 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 29 Dec 2025 09:05:04 +0000 Subject: [PATCH 099/151] Bump dawidd6/action-download-artifact from 11 to 12 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 11 to 12 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v11...v12) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '12' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index fbd7c89a..76c76e31 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v11 + uses: dawidd6/action-download-artifact@v12 with: workflow: docs-build.yml name: docs-build From 9ef7c4692dd8c7c7a85d491dd370858244df58cb Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 Jan 2026 00:01:21 +0100 Subject: [PATCH 100/151] Bump version to 1.14.0.dev0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 4abca400..a28101d6 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.13.0' +__version__ = '1.14.0.dev0' from .common import * # noqa: F401, F403 From a9704103d46a5e86c7a38ffbec071c9ba196969f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 Jan 2026 13:29:42 +0100 Subject: [PATCH 101/151] BUG: torch: fix up clip --- array_api_compat/torch/_aliases.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index e40183d8..4b232f84 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -4,6 +4,7 @@ from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any from typing import Any, Literal +import math import torch @@ -857,13 +858,24 @@ def _isscalar(a: object): min_is_scalar = _isscalar(min) max_is_scalar = _isscalar(max) - if min is not None and max is not None: - if min_is_scalar and not max_is_scalar: - min = torch.as_tensor(min, dtype=x.dtype, device=x.device) - if max_is_scalar and not min_is_scalar: - max = torch.as_tensor(max, dtype=x.dtype, device=x.device) + if min_is_scalar and max_is_scalar: + if (min is not None and math.isnan(min)) or (max is not None and math.isnan(max)): + # edge case: torch.clamp(torch.zeros(1), float('nan')) -> tensor(0.) + # https://github.com/pytorch/pytorch/issues/172067 + return torch.full_like(x, fill_value=torch.nan) + return torch.clamp(x, min, max, **kwargs) - return torch.clamp(x, min, max, **kwargs) + # pytorch has (tensor, tensor, tensor) and (tensor, scalar, scalar) signatures, + # but does not accept (tensor, scalar, tensor) + a_min = min + if min is not None and min_is_scalar: + a_min = torch.as_tensor(min, dtype=x.dtype, device=x.device) + + a_max = max + if max is not None and max_is_scalar: + a_max = torch.as_tensor(max, dtype=x.dtype, device=x.device) + + return torch.clamp(x, a_min, a_max, **kwargs) def sign(x: Array, /) -> Array: From 020e167b234223258af9bc3989dbe287237095fc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 9 Jan 2026 14:53:27 +0100 Subject: [PATCH 102/151] DOC: Add note to update version attribute in releasing process (#377) --- docs/dev/releasing.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/dev/releasing.md b/docs/dev/releasing.md index 1ee17709..f6fca85c 100644 --- a/docs/dev/releasing.md +++ b/docs/dev/releasing.md @@ -100,6 +100,11 @@ docs update (the docs are published automatically from the sources on `main`). +- [ ] **Bump the `__version__` attribute in `__init__.py`** + + After an M.N.0 release, further development is towards version `M.(N+1).0`, thus the main branch's + version is `M.(N+1).0.dev0`. + - [ ] **Update conda-forge.** After the PyPI package is published, the conda-forge bot should update the From c1b16c89342305ec45822a303bf0f91286662b3c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 21 Jan 2026 21:19:36 +0100 Subject: [PATCH 103/151] TST: test that clip(x) returns a copy not a view The spec only says that "If both min and max are None, the elements of the returned array must equal the respective elements in x" Bare NumPy 2.x and CuPy 13.x return copes: >>> x = np.arange(8); np.may_share_memory(x, np.clip(x)) False Thus assume that all wrapped libraries should return a copy, too. Add a test to this effect. --- tests/test_copies_or_views.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_copies_or_views.py b/tests/test_copies_or_views.py index ec8995f7..1e564694 100644 --- a/tests/test_copies_or_views.py +++ b/tests/test_copies_or_views.py @@ -62,3 +62,15 @@ def test_view_or_copy(inputs, xp_name): is_view_wrapped = is_view(wrapped_func, a1, value) assert is_view_bare == is_view_wrapped + + +@pytest.mark.parametrize('xp_name', wrapped_libraries + ['array_api_strict']) +def test_clip_none(xp_name): + xp = import_(xp_name, wrapper=True) + + if xp_name == 'array_api_strict' and xp.__version__ < "2.5": + # https://github.com/data-apis/array-api-strict/pull/180 + pytest.xfail("clip(x) was only fixed in -strict == 2.5") + + x = xp.arange(8) + assert not is_view(xp.clip, x, 42) From 8e98058e87f24f6a70db2b2d15bbb3ec828b1284 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 19 Jan 2026 14:05:44 +0100 Subject: [PATCH 104/151] TST: update -xfails for complex special cases --- cupy-xfails.txt | 11 +++++++++++ dask-xfails.txt | 12 ++++++++++++ numpy-1-22-xfails.txt | 13 +++++++++++++ numpy-1-26-xfails.txt | 16 ++++++++++++++++ numpy-dev-xfails.txt | 12 ++++++++++++ numpy-xfails.txt | 13 +++++++++++++ torch-xfails.txt | 15 +++++++++++++++ 7 files changed, 92 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 0a91cafe..25ff654c 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -176,6 +176,17 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] +# complex cases +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[log(real(x_i) is -0 and imag(x_i) is +0) -> -infinity + \u03c0j] +array_api_tests/test_special_cases.py::test_unary[log((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[log(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[log1p((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[log1p(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is +infinity) -> +0 + NaN j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is NaN) -> +0 + NaN j] + # CuPy gives the wrong shape for n-dim fft funcs. See # https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 array_api_tests/test_fft.py::test_fftn diff --git a/dask-xfails.txt b/dask-xfails.txt index 3efb4f96..5ce24254 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -148,3 +148,15 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index 5df1b6d7..ab9e47a8 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -173,3 +173,16 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] + +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 98cb9f6c..060bacc9 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -71,3 +71,19 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + + + + diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 972d2346..25552853 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -38,3 +38,15 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i < 0) -> +infinity] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] + +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 632b4ec3..2150203c 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -39,3 +39,16 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i > 0 and x2_i is -infinity) -> -0] array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(isfinite(x1_i) and x1_i < 0 and x2_i is +infinity) -> -0] +# complex special cases +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + +array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] diff --git a/torch-xfails.txt b/torch-xfails.txt index 989df0c8..589f3b52 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -108,6 +108,21 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] +# complex cases +array_api_tests/test_special_cases.py::test_unary[acos((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> \u03c0/2 - 0j] +array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +0) -> +infinity + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is +infinity) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is +infinity) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is -infinity and imag(x_i) is NaN) -> -1 + 0j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is +infinity and imag(x_i) is NaN) -> infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[expm1(real(x_i) is NaN and imag(x_i) is +0) -> NaN + 0j] +array_api_tests/test_special_cases.py::test_unary[log1p(isfinite(real(x_i)) and imag(x_i) is +infinity) -> +infinity + \u03c0j/2] +array_api_tests/test_special_cases.py::test_unary[log1p((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[log1p(real(x_i) is NaN and imag(x_i) is +infinity) -> +infinity + NaN j] +array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + # Float correction is not supported by pytorch # (https://github.com/data-apis/array-api-tests/issues/168) array_api_tests/test_statistical_functions.py::test_std From 9b28a4c1225c7d1be8fd0b30e921598e5f6d35a6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Feb 2026 10:22:04 +0000 Subject: [PATCH 105/151] Bump dawidd6/action-download-artifact from 12 to 14 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 12 to 14 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v12...v14) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '14' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 76c76e31..07d6f575 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v12 + uses: dawidd6/action-download-artifact@v14 with: workflow: docs-build.yml name: docs-build From 35b631f718e47985b82119a0b02b6342f0304bc1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 5 Feb 2026 16:10:13 +0100 Subject: [PATCH 106/151] WIP: add axis tuple support to torch.expand_dims --- array_api_compat/torch/_aliases.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b232f84..512c1060 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -690,9 +690,24 @@ def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: Array, /, *, axis: int = 0) -> Array: - return torch.unsqueeze(x, axis) +def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array: + if isinstance(axis, int): + return torch.unsqueeze(x, axis) + else: + # follow https://github.com/numpy/numpy/blob/maintenance/2.4.x/numpy/lib/_shape_base_impl.py#L596-L602 + y_ndim = x.ndim + len(axis) + + # normalize + n_axis = tuple(ax + y_ndim if ax < 0 else ax for ax in axis) + if (len(n_axis) != len(set(n_axis)) or + _builtin_any(ax < 0 or ax >= y_ndim for ax in n_axis) + ): + raise ValueError(f"{axis=} not allowed for {x.shape = }") + + shape_it = iter(x.shape) + shape = [1 if ax in n_axis else next(shape_it) for ax in range(y_ndim)] + return torch.reshape(x, shape) def astype( x: Array, From 6048e484bda16d32ceeabbd4a7a1c67d56268bdf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 8 Feb 2026 20:23:54 +0100 Subject: [PATCH 107/151] TST: remove test_signature xfails --- cupy-xfails.txt | 5 ----- numpy-1-26-xfails.txt | 5 ----- numpy-dev-xfails.txt | 7 ------- numpy-xfails.txt | 7 ------- torch-xfails.txt | 11 ----------- 5 files changed, 35 deletions(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 25ff654c..fd92bde5 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -203,11 +203,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i is -0) -> -0] # cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 060bacc9..d439990a 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -43,11 +43,6 @@ array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 25552853..57f3d48f 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -12,13 +12,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat -# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] - # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 2150203c..9c9afe2e 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -12,13 +12,6 @@ array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] # uint64 repeats not supported array_api_tests/test_manipulation_functions.py::test_repeat -# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] - # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/torch-xfails.txt b/torch-xfails.txt index 589f3b52..32779415 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -146,17 +146,6 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# 2024.12 support -array_api_tests/test_signatures.py::test_func_signature[bitwise_and] -array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_or] -array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] -array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] -array_api_tests/test_signatures.py::test_array_method_signature[__and__] -array_api_tests/test_signatures.py::test_array_method_signature[__lshift__] -array_api_tests/test_signatures.py::test_array_method_signature[__or__] -array_api_tests/test_signatures.py::test_array_method_signature[__rshift__] -array_api_tests/test_signatures.py::test_array_method_signature[__xor__] # 2024.12 support: binary functions reject python scalar arguments array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] From d2dfd58aeaddbcb612da6fae1224170bc45425ee Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 9 Feb 2026 19:11:05 +0000 Subject: [PATCH 108/151] MAINT: update cupy xfails --- cupy-xfails.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index fd92bde5..e63646f5 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -44,6 +44,11 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_sc array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[copysign] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[divide] array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[maximum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[add] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[minimum] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[multiply] +array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[subtract] + # cupy (arg)min/max wrong with infinities # https://github.com/cupy/cupy/issues/7424 @@ -176,7 +181,7 @@ array_api_tests/test_special_cases.py::test_unary[tan(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[tanh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[trunc(x_i is -0) -> -0] -# complex cases +# complex spec cases array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] array_api_tests/test_special_cases.py::test_unary[log(real(x_i) is -0 and imag(x_i) is +0) -> -infinity + \u03c0j] array_api_tests/test_special_cases.py::test_unary[log((real(x_i) is +infinity or real(x_i) == -infinity) and imag(x_i) is NaN) -> +infinity + NaN j] @@ -186,6 +191,7 @@ array_api_tests/test_special_cases.py::test_unary[log1p(real(x_i) is NaN and ima array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is +infinity) -> +0 + NaN j] array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +0 and imag(x_i) is NaN) -> +0 + NaN j] +array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] # CuPy gives the wrong shape for n-dim fft funcs. See # https://github.com/data-apis/array-api-compat/pull/78#issuecomment-1984527870 From 955d146095c04b8992e6a3599901ef030e7732ab Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 17 Feb 2026 19:37:20 +0000 Subject: [PATCH 109/151] BUG: torch: cast short ints in repeat Otherwise, "repeat_interleave is not implemented for Char" etc. --- array_api_compat/torch/_aliases.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b232f84..4a7e9c00 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -573,6 +573,10 @@ def count_nonzero( # "repeat" is torch.repeat_interleave; also the dim argument def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Array: + if isinstance(repeats, torch.Tensor) and repeats.dtype in (torch.int8, torch.int16): + # torch rejects short integers for the `repeat` argument: + # https://github.com/pytorch/pytorch/issues/151311 + repeats = repeats.to(torch.int32) return torch.repeat_interleave(x, repeats, axis) From c5a0028ef9d7ad7fd268b3dcec72167d461674c7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 17 Feb 2026 19:48:15 +0000 Subject: [PATCH 110/151] TST: unskip test_repeat on torch and numpy --- numpy-1-22-xfails.txt | 2 -- numpy-1-26-xfails.txt | 2 -- numpy-dev-xfails.txt | 4 ---- numpy-xfails.txt | 4 ---- torch-xfails.txt | 2 -- 5 files changed, 14 deletions(-) diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index ab9e47a8..e2df8b47 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -143,8 +143,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_sca array_api_tests/test_has_names.py::test_has_names[creation-from_dlpack] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_signatures.py::test_func_signature[bitwise_and] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index d439990a..407cc531 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -39,8 +39,6 @@ array_api_tests/meta/test_hypothesis_helpers.py::test_symmetric_matrices # 2023.12 support array_api_tests/test_signatures.py::test_func_signature[from_dlpack] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat # 2024.12 support array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars diff --git a/numpy-dev-xfails.txt b/numpy-dev-xfails.txt index 57f3d48f..45d3338f 100644 --- a/numpy-dev-xfails.txt +++ b/numpy-dev-xfails.txt @@ -8,10 +8,6 @@ array_api_tests/test_data_type_functions.py::test_finfo[complex64] array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] -# 2023.12 support -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat - # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/numpy-xfails.txt b/numpy-xfails.txt index 9c9afe2e..d8707c39 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -8,10 +8,6 @@ array_api_tests/test_data_type_functions.py::test_finfo[complex64] array_api_tests/test_signatures.py::test_func_signature[vecdot] array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot] -# 2023.12 support -# uint64 repeats not supported -array_api_tests/test_manipulation_functions.py::test_repeat - # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/torch-xfails.txt b/torch-xfails.txt index 32779415..3c2a0028 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -139,8 +139,6 @@ array_api_tests/test_data_type_functions.py::test_finfo_dtype array_api_tests/test_data_type_functions.py::test_iinfo_dtype # 2023.12 support -# https://github.com/pytorch/pytorch/issues/151311: torch.repeat_interleave rejects short integers -array_api_tests/test_manipulation_functions.py::test_repeat # Argument 'device' missing from signature array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature From 0e68eac4692eeeb6f6d4bde3b44320b536542c6e Mon Sep 17 00:00:00 2001 From: Josh Soref <2119212+jsoref@users.noreply.github.com> Date: Wed, 18 Feb 2026 10:27:18 -0500 Subject: [PATCH 111/151] Merge pull request #389 from jsoref/spelling Fix several typos --- array_api_compat/common/_helpers.py | 2 +- array_api_compat/dask/array/linalg.py | 2 +- docs/changelog.md | 4 ++-- docs/dev/releasing.md | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 8194a083..bc90d208 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -593,7 +593,7 @@ def array_namespace( use_compat: bool or None If None (the default), the native namespace will be returned if it is - already array API compatible, otherwise a compat wrapper is used. If + already array API compatible; otherwise, a compat wrapper is used. If True, the compat library wrapped library will be returned. If False, the native library namespace is returned. diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 6b3c1011..a9be5d5f 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -43,7 +43,7 @@ def qr( # type: ignore[no-redef] # and dask doesn't have the full_matrices keyword def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: - raise ValueError("full_matrics=True is not supported by dask.") + raise ValueError("full_matrices=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) def svdvals(x: Array) -> Array: diff --git a/docs/changelog.md b/docs/changelog.md index fe07c9e0..a9b355a0 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -221,7 +221,7 @@ Thomas Li `xp.__array_namespace_info__()`. - Various fixes to the `clip()` wrappers. -- `torch.conj` now wrapps `torch.conj_physical`, which makes a copy rather +- `torch.conj` now wraps `torch.conj_physical`, which makes a copy rather than setting the conjugation bit, as arrays with the conjugation bit set do not support some APIs. @@ -298,7 +298,7 @@ Thomas Li - New flag `use_compat` to {func}`~.array_namespace` to force the use or non-use of the compat wrapper namespace. The default is to return a compat - namespace when it is appropiate. + namespace when it is appropriate. - Fix the `copy` flag to `asarray` for NumPy, CuPy, and Dask. diff --git a/docs/dev/releasing.md b/docs/dev/releasing.md index f6fca85c..21d3c36a 100644 --- a/docs/dev/releasing.md +++ b/docs/dev/releasing.md @@ -24,7 +24,7 @@ This does mean you can ignore CI failures, but ideally you should fix any failures or update the `*-xfails.txt` files before tagging, so that CI and - the CuPy tests fully pass. Otherwise it will be hard to tell what things are + the CuPy tests fully pass. Otherwise, it will be hard to tell what things are breaking in the future. It's also a good idea to remove any xpasses from those files (but be aware that some xfails are from flaky failures, so unless you know the underlying issue has been fixed, an xpass test is From 6652499b56c8ac0301b499b48e0494fcc679bae3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:07:39 +0000 Subject: [PATCH 112/151] Bump dawidd6/action-download-artifact from 14 to 15 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 14 to 15 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v14...v15) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '15' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 07d6f575..d237595a 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v14 + uses: dawidd6/action-download-artifact@v15 with: workflow: docs-build.yml name: docs-build From f15dc6cccf5da66ad9144d21e713317381bc3286 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 23 Feb 2026 12:48:05 +0100 Subject: [PATCH 113/151] CI: run tests against the 2025.12 revision Update ARRAY_API_TESTS_VERSION to 2025.12 --- .github/workflows/array-api-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/array-api-tests.yml b/.github/workflows/array-api-tests.yml index 8d78225c..67ae76f5 100644 --- a/.github/workflows/array-api-tests.yml +++ b/.github/workflows/array-api-tests.yml @@ -84,7 +84,7 @@ jobs: - name: Run the array API testsuite (${{ inputs.package-name }}) env: ARRAY_API_TESTS_MODULE: array_api_compat.${{ inputs.module-name || inputs.package-name }} - ARRAY_API_TESTS_VERSION: 2024.12 + ARRAY_API_TESTS_VERSION: 2025.12 # This enables the NEP 50 type promotion behavior (without it a lot of # tests fail on bad scalar type promotion behavior) NPY_PROMOTION_STATE: weak From da06daed0b5e93f6cd7d5bce7c36d4057759f37e Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 22:37:04 +0100 Subject: [PATCH 114/151] torch.broadcast_arrays: make it return a tuple --- array_api_compat/torch/_aliases.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4a7e9c00..10e8186e 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -711,9 +711,9 @@ def astype( return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> list[Array]: +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) - return [torch.broadcast_to(a, shape) for a in arrays] + return tuple(torch.broadcast_to(a, shape) for a in arrays) # Note that these named tuples aren't actually part of the standard namespace, # but I don't see any issue with exporting the names here regardless. From ddf995387f8337e19d80a463d1384501c1fc708b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 22:37:18 +0100 Subject: [PATCH 115/151] cupy.broadcast_arrays: make it return a tuple --- array_api_compat/cupy/_aliases.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 2e512fc8..badfe390 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -139,6 +139,11 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) +# https://github.com/cupy/cupy/pull/9582 +def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: + return tuple(cp.broadcast_arrays(*arrays)) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -161,7 +166,8 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'ceil', 'floor', 'trunc', 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis', + 'broadcast_arrays',] def __dir__() -> list[str]: From 94c1706d20a2b07331cbb2028ba12696e8c6e4b6 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 Jan 2026 11:07:42 +0100 Subject: [PATCH 116/151] torch.meshgrid: make it return tuple, not list --- array_api_compat/torch/_aliases.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 10e8186e..27f0a263 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -897,10 +897,11 @@ def sign(x: Array, /) -> Array: return out -def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> list[Array]: - # enforce the default of 'xy' - # TODO: is the return type a list or a tuple - return list(torch.meshgrid(*arrays, indexing=indexing)) +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: + # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it + # will be required to pass the indexing argument." + # Thus always pass it explicitly. + return torch.meshgrid(*arrays, indexing=indexing) __all__ = ['asarray', 'result_type', 'can_cast', From af5fd5ca220c5b177d92faa21e3ed28ba9b646d7 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sun, 11 Jan 2026 11:19:38 +0100 Subject: [PATCH 117/151] __array_namespace_info().devices() : returns a tuple not list --- array_api_compat/cupy/_info.py | 2 +- array_api_compat/dask/array/_info.py | 4 ++-- array_api_compat/numpy/_info.py | 4 ++-- array_api_compat/torch/_info.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/array_api_compat/cupy/_info.py b/array_api_compat/cupy/_info.py index 78e48a33..aef10e85 100644 --- a/array_api_compat/cupy/_info.py +++ b/array_api_compat/cupy/_info.py @@ -333,4 +333,4 @@ def devices(self): __array_namespace_info__.dtypes """ - return [cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())] + return tuple(cuda.Device(i) for i in range(cuda.runtime.getDeviceCount())) diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 2f39fc4b..3a7285d5 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -379,7 +379,7 @@ def dtypes( return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[Device]: + def devices(self) -> tuple[Device]: """ The devices supported by Dask. @@ -404,4 +404,4 @@ def devices(self) -> list[Device]: ['cpu', DASK_DEVICE] """ - return ["cpu", _DASK_DEVICE] + return ("cpu", _DASK_DEVICE) diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index c625c13e..9ba004da 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -332,7 +332,7 @@ def dtypes( return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[Device]: + def devices(self) -> tuple[Device]: """ The devices supported by NumPy. @@ -357,7 +357,7 @@ def devices(self) -> list[Device]: ['cpu'] """ - return ["cpu"] + return ("cpu",) __all__ = ["__array_namespace_info__"] diff --git a/array_api_compat/torch/_info.py b/array_api_compat/torch/_info.py index 818e5d37..050c7846 100644 --- a/array_api_compat/torch/_info.py +++ b/array_api_compat/torch/_info.py @@ -366,4 +366,4 @@ def devices(self): break i += 1 - return devices + return tuple(devices) From 0959644ac11a1b38c7f5e3b45430417e1ffc3236 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 23 Feb 2026 13:22:07 +0100 Subject: [PATCH 118/151] TST: update xfails for the tuples/lists spec change --- dask-xfails.txt | 7 +++++++ numpy-1-22-xfails.txt | 7 +++++++ numpy-1-26-xfails.txt | 8 +++++++- torch-xfails.txt | 5 +++++ 4 files changed, 26 insertions(+), 1 deletion(-) diff --git a/dask-xfails.txt b/dask-xfails.txt index 5ce24254..4d5a89af 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -129,6 +129,13 @@ array_api_tests/test_linalg.py::test_matrix_norm array_api_tests/test_linalg.py::test_qr array_api_tests/test_manipulation_functions.py::test_roll +# 2025.12 support +array_api_tests/test_has_names.py::test_has_names[manipulation-broadcast_shapes] +array_api_tests/test_signatures.py::test_func_signature[broadcast_shapes] +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast_shapes +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index e2df8b47..c2e13d35 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -152,6 +152,13 @@ array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift] array_api_tests/test_signatures.py::test_func_signature[bitwise_xor] array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars +# 2025.12 support + +# older numpies return lists not tuples +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_data_type_functions.py::test_broadcast_arrays + + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently,NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 407cc531..f09be4a4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -42,9 +42,15 @@ array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] # 2024.12 support array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars - array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_scalars +# 2025.12 support + +# older numpies return lists not tuples +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_data_type_functions.py::test_broadcast_arrays + + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/torch-xfails.txt b/torch-xfails.txt index 3c2a0028..84271a56 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -144,6 +144,11 @@ array_api_tests/test_signatures.py::test_func_signature[from_dlpack] # Argument 'max_version' missing from signature array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] +# 2025.12 support + +# broadcast_shapes emits a RuntimeError where the spec says ValueError +array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error + # 2024.12 support: binary functions reject python scalar arguments array_api_tests/test_operators_and_elementwise_functions.py::test_binary_with_scalars_real[atan2] From d2e2f0ec06a94cdb98dee53444926b28bf9bc5b2 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Mon, 23 Feb 2026 13:07:51 +0000 Subject: [PATCH 119/151] cupy.meshgrid: return a tuple not a list --- array_api_compat/cupy/_aliases.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index badfe390..f91805f2 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,7 @@ from __future__ import annotations from builtins import bool as py_bool +from typing import Literal import cupy as cp @@ -144,6 +145,10 @@ def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: return tuple(cp.broadcast_arrays(*arrays)) +def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: + return tuple(cp.meshgrid(*arrays, indexing=indexing)) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -167,7 +172,7 @@ def broadcast_arrays(*arrays: Array) -> tuple[Array, ...]: 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'ceil', 'floor', 'trunc', 'take_along_axis', - 'broadcast_arrays',] + 'broadcast_arrays', 'meshgrid'] def __dir__() -> list[str]: From cfb77ee55e70d0662e74f0d6a4c436f9db7d838a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 19:32:35 +0100 Subject: [PATCH 120/151] TST: dask.linalg.eig is not a thing --- dask-xfails.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index 4d5a89af..bc153d87 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -136,6 +136,9 @@ array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error +array_api_tests/test_linalg.py::test_eig +array_api_tests/test_linalg.py::test_eigvals + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] From b01b04cd2710566e2df7b84f852e748014fb1b3b Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 20:41:45 +0100 Subject: [PATCH 121/151] ENH: numpy: add linalg.{eig,eigvals} We need a wrapper because numpy currently returns `float|complex`. Implementation-wise, follow `linalg.solve` and copy-paste relevant numpy code with minimal required modifications. --- array_api_compat/common/_linalg.py | 4 ++ array_api_compat/numpy/linalg.py | 80 ++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 69672af7..14b560d1 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -34,6 +34,10 @@ class EighResult(NamedTuple): eigenvalues: Array eigenvectors: Array +class EigResult(NamedTuple): + eigenvalues: Array + eigenvectors: Array + class QRResult(NamedTuple): Q: Array R: Array diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 7168441c..474efe50 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -19,6 +19,7 @@ cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) EighResult = _linalg.EighResult +EigResult = _linalg.EigResult QRResult = _linalg.QRResult SlogdetResult = _linalg.SlogdetResult SVDResult = _linalg.SVDResult @@ -97,6 +98,85 @@ def solve(x1: Array, x2: Array, /) -> Array: return wrap(r.astype(result_t, copy=False)) +# Unlike numpy.linalg.eig, Array API version always returns complex results + +def eig(x: Array, /) -> tuple[Array, Array]: + try: + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + except ImportError: + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + from numpy.linalg import _umath_linalg + + x, wrap = _makearray(x) + _assert_stacked_square(x) + _assert_finite(x) + t, result_t = _commonType(x) + + signature = 'D->DD' if isComplexType(t) else 'd->DD' + with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w, vt = _umath_linalg.eig(x, signature=signature) + + result_t = _complexType(result_t) + vt = vt.astype(result_t, copy=False) + return EigResult(w.astype(result_t, copy=False), wrap(vt)) + + +def eigvals(x: Array, /) -> Array: + try: + from numpy.linalg._linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + except ImportError: + from numpy.linalg.linalg import ( # type: ignore[attr-defined] + _assert_stacked_square, + _assert_finite, + _commonType, + _makearray, + _raise_linalgerror_eigenvalues_nonconvergence, + isComplexType, + _complexType, + ) + from numpy.linalg import _umath_linalg + + x, wrap = _makearray(x) + _assert_stacked_square(x) + _assert_finite(x) + t, result_t = _commonType(x) + + signature = 'D->D' if isComplexType(t) else 'd->D' + with np.errstate(call=_raise_linalgerror_eigenvalues_nonconvergence, + invalid='call', over='ignore', divide='ignore', + under='ignore'): + w = _umath_linalg.eigvals(x, signature=signature) + + result_t = _complexType(result_t) + return w.astype(result_t, copy=False) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(np.linalg, "vector_norm"): From d4c6f5d608443a142161478099cd88771eb04549 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Sat, 10 Jan 2026 20:48:04 +0100 Subject: [PATCH 122/151] TST: cupy: xfail eig tests --- cupy-xfails.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index e63646f5..7e147635 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -24,6 +24,10 @@ array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] array_api_tests/test_linalg.py::test_solve +# 2025.12 support; {eig,eigvals} are new in CuPy 14 +array_api_tests/test_linalg.py::test_eig +array_api_tests/test_linalg.py::test_eigvals + # We cannot modify array methods array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_floor_divide[__floordiv__(x, s)] From 661b5312bc369f726af8126de40304dbca849b8f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 Jan 2026 20:15:34 +0000 Subject: [PATCH 123/151] ENH: cupy: add a workaround for cp.searchorted 2nd argument Array API 2025.12 allows python scalars for the x2 argument of `searchsorted`. CuPy only supports python scalars for x2 from CuPy 14.0. Until this is the minimum supported version, array-api-compat needs a workaround. --- array_api_compat/cupy/_aliases.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index f91805f2..7b7bfda6 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -149,6 +149,24 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra return tuple(cp.meshgrid(*arrays, indexing=indexing)) +# Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum +# supported version +def searchsorted( + x1: Array, + x2: Array | int | float, + /, + *, + side: Literal['left', 'right'] = 'left', + sorter: Array | None = None +) -> Array: + if not isinstance(x2, cp.ndarray): + if not isinstance(x2, int | float | complex): + raise NotImplementedError( + 'Only python scalars or ndarrays are supported for x2') + x2 = cp.asarray(x2) + return cp.searchsorted(x1, x2, side, sorter) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -172,7 +190,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'ceil', 'floor', 'trunc', 'take_along_axis', - 'broadcast_arrays', 'meshgrid'] + 'broadcast_arrays', 'meshgrid', + 'searchsorted', +] def __dir__() -> list[str]: From 3ececee1317f54d095d2abff9b72d89040ac4b95 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 Jan 2026 21:56:33 +0100 Subject: [PATCH 124/151] TST: dask: xfail searchorted_scalars test on dask Dask does not allow scalars as arguments to searchsorted: $ ARRAY_API_TESTS_MODULE=array_api_compat.dask.array pytest array_api_tests/test_searching_functions.py::test_searchsorted_with_scalars --max-examples 500 ... @given(data=st.data()) > def test_searchsorted_with_scalars(data): ^^^ ... # call np.searchsorted for each pair of blocks in a and v > meta = np.searchsorted(a._meta, v._meta) ^^^^^^^ E AttributeError: 'int' object has no attribute '_meta' E E ========== FAILING CODE SNIPPET: E xp.searchsorted(dask.array, 0, sorter=None, **kw) with kw = {} E ==================== --- dask-xfails.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index bc153d87..93a138cf 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -135,6 +135,7 @@ array_api_tests/test_signatures.py::test_func_signature[broadcast_shapes] array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast_shapes array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error +array_api_tests/test_searching_functions.py::test_searchsorted_with_scalars array_api_tests/test_linalg.py::test_eig array_api_tests/test_linalg.py::test_eigvals From f48696934c449d111a6736dec8c8ff75019d3ecf Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 13:01:42 +0100 Subject: [PATCH 125/151] TST: add skips for dask/expand_dims and numpy<2 isin --- dask-xfails.txt | 2 ++ numpy-1-26-xfails.txt | 2 ++ 2 files changed, 4 insertions(+) diff --git a/dask-xfails.txt b/dask-xfails.txt index 93a138cf..c2c54af2 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -140,6 +140,8 @@ array_api_tests/test_searching_functions.py::test_searchsorted_with_scalars array_api_tests/test_linalg.py::test_eig array_api_tests/test_linalg.py::test_eigvals +array_api_tests/test_manipulation_functions.py::TestExpandDims::test_expand_dims_tuples + # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.) array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity] diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index f09be4a4..831de9e4 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -50,6 +50,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_sca array_api_tests/test_creation_functions.py::test_meshgrid array_api_tests/test_data_type_functions.py::test_broadcast_arrays +# observed with numpy==1.26 only, looks like is fixed on numpy 2.x +array_api_tests/test_set_functions.py::TestIsin::test_isin_scalars # Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity] From 35b06e5f51008baf9bcd08938edb93a1f031969f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 13:03:16 +0100 Subject: [PATCH 126/151] BUG: torch: expand_dims axis is keyword or positional https://data-apis.org/array-api/2025.12/API_specification/generated/array_api.expand_dims.html --- array_api_compat/torch/_aliases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index f97a9c6f..69bd3763 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -694,7 +694,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array: +def expand_dims(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): return torch.unsqueeze(x, axis) else: From f8097331f29064c978d8fe2f46fabada8578bdcc Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 14:55:32 +0100 Subject: [PATCH 127/151] ENH: cupy: make isin accept int scalars --- array_api_compat/cupy/_aliases.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 7b7bfda6..44808ec9 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -167,6 +167,15 @@ def searchsorted( return cp.searchsorted(x1, x2, side, sorter) +# CuPy isin does not accept scalars +def isin(x1: Array | int, x2: Array | int, /, *, invert: bool = False, **kwds) -> Array: + if isinstance(x1, int): + x1 = cp.asarray(x1) + if isinstance(x2, int): + x2 = cp.asarray(x2) + return cp.isin(x1, x2, invert=invert, **kwds) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -191,7 +200,7 @@ def searchsorted( 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'ceil', 'floor', 'trunc', 'take_along_axis', 'broadcast_arrays', 'meshgrid', - 'searchsorted', + 'searchsorted', 'isin', ] From 3818fa9ab55a1195cff83925ac6987cabf1675c4 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 18:27:15 +0000 Subject: [PATCH 128/151] ENH: bump wrapped API versions to 2025.12 --- array_api_compat/cupy/__init__.py | 2 +- array_api_compat/dask/array/__init__.py | 2 +- array_api_compat/numpy/__init__.py | 2 +- array_api_compat/torch/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index af003c5a..558a83e1 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -12,7 +12,7 @@ __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__: Final = '2024.12' +__array_api_version__: Final = '2025.12' __all__ = sorted( {name for name in globals() if not name.startswith("__")} diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index f78aa8b3..1905c671 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -9,7 +9,7 @@ from ._aliases import * # type: ignore[assignment] # noqa: F403 from ._info import __array_namespace_info__ # noqa: F401 -__array_api_version__: Final = "2024.12" +__array_api_version__: Final = "2025.12" del Final # See the comment in the numpy __init__.py diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 23379e44..81eaafef 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -26,7 +26,7 @@ from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 -__array_api_version__: Final = "2024.12" +__array_api_version__: Final = "2025.12" __all__ = sorted( set(__all__) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 6cbb6ec2..8263faa6 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -13,7 +13,7 @@ __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__: Final = '2024.12' +__array_api_version__: Final = '2025.12' __all__ = sorted( set(__all__) From 12e2e5bcbb52d56a0e903ce4a4a02d74c3f69a6a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 18:42:07 +0000 Subject: [PATCH 129/151] ENH: bump version to 2025.12 in common/_helpers.py --- array_api_compat/common/_helpers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index bc90d208..8a307f9d 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -55,8 +55,8 @@ | SupportsArrayNamespace[Any] ) -_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) -_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) +_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12", "2024.12"}) +_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2025.12"}) @lru_cache(100) @@ -485,11 +485,11 @@ def is_array_api_strict_namespace(xp: Namespace) -> bool: def _check_api_version(api_version: str | None) -> None: if api_version in _API_VERSIONS_OLD: warnings.warn( - f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12" + f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2025.12" ) elif api_version is not None and api_version not in _API_VERSIONS: raise ValueError( - "Only the 2024.12 version of the array API specification is currently supported" + "Only the 2025.12 version of the array API specification is currently supported" ) @@ -589,7 +589,7 @@ def array_namespace( api_version: str The newest version of the spec that you need support for (currently - the compat library wrapped APIs support v2024.12). + the compat library wrapped APIs support v2025.12). use_compat: bool or None If None (the default), the native namespace will be returned if it is From 4f8a4aa1bb1f3d56348248a33e6a66d7905db100 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 24 Feb 2026 18:48:16 +0000 Subject: [PATCH 130/151] TST: bump the API version in test_cupy.sh --- test_cupy.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_cupy.sh b/test_cupy.sh index a6974333..8ac72dc6 100755 --- a/test_cupy.sh +++ b/test_cupy.sh @@ -26,5 +26,5 @@ mkdir -p $SCRIPT_DIR/.hypothesis ln -s $SCRIPT_DIR/.hypothesis .hypothesis export ARRAY_API_TESTS_MODULE=array_api_compat.cupy -export ARRAY_API_TESTS_VERSION=2024.12 +export ARRAY_API_TESTS_VERSION=2025.12 pytest array_api_tests/ ${PYTEST_ARGS} --xfails-file $SCRIPT_DIR/cupy-xfails.txt "$@" From a5a3cf15465cdef1905e37b8b8ffa0157078e6c5 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Feb 2026 08:54:54 +0100 Subject: [PATCH 131/151] TST: add 2025 names to test_all --- tests/test_all.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_all.py b/tests/test_all.py index c36aef67..d9350ce7 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -140,6 +140,7 @@ # Manipulation Functions "broadcast_arrays", "broadcast_to", + "broadcast_shapes", "concat", "expand_dims", "flip", @@ -164,6 +165,7 @@ "unique_counts", "unique_inverse", "unique_values", + "isin", # Sorting Functions "argsort", "sort", @@ -205,6 +207,8 @@ "diagonal", "eigh", "eigvalsh", + "eig", + "eigvals", "inv", "matmul", "matrix_norm", @@ -227,12 +231,14 @@ XFAILS = { ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], - ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", ""): ["from_dlpack", "take_along_axis", "broadcast_shapes"], ("dask.array", "linalg"): [ "cross", "det", "eigh", "eigvalsh", + "eig", + "eigvals", "matrix_power", "pinv", "slogdet", From 2ddef55a7ea097c43fd90010bb3bc5f60372feb0 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Feb 2026 09:32:46 +0000 Subject: [PATCH 132/151] BUG: cupy/linalg: include non-array api names --- array_api_compat/cupy/linalg.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index da301574..1943cb15 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -1,13 +1,25 @@ from cupy.linalg import * # noqa: F403 -# cupy.linalg doesn't have __all__. If it is added, replace this with + +# https://github.com/cupy/cupy/issues/9749 +from cupy.linalg import lstsq # noqa: F401 + +# cupy.linalg doesn't have __all__ in cupy<14. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all _n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] -linalg_all = list(_n) +linalg_all = list(_n) + ['lstsq'] del _n +try: + # cupy 14 exports it, cupy 13 does not + from cupy.linalg import annotations # noqa: F401 + linalg_all += ['annotations'] +except ImportError: + pass + + from ..common import _linalg from .._internal import get_xp @@ -43,5 +55,8 @@ __all__ = linalg_all + _linalg.__all__ +# cupy 13 does not have __all__, cupy 14 has it: remove duplicates +__all__ = sorted(list(set(__all__))) + def __dir__() -> list[str]: return __all__ From 95e44b14d39d30fb296a169bb9b9bd5462dc3c1a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Feb 2026 12:12:11 +0100 Subject: [PATCH 133/151] DOC: add 1.14 changelog --- docs/changelog.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/docs/changelog.md b/docs/changelog.md index a9b355a0..527dd3bd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,37 @@ # Changelog +## 1.14.0 (2026-02-26) + +### Major changes + +This release targets the 2025.12 Array API revision. This includes + + - `__array_api_version__` for the wrapped APIs is now set to `2025.12`; + - wrappers for `linalg.eig` and `linalg.eigvals`; + - wrappers for `isin` and `searchsorted` to accept Python scalars; + - wrappers for `expand_dims` accepting tuple axes; + - `broadcast_arrays`, `meshgrid` and `__array_api_info__().devices()` have been + changed to return tuples, not lists; + +Additionally, + + - `clip` wrappers have been fixed to be compatible with `torch.vmap`. + + +### Minor changes + + - `expand_dims` wrappers have been fixed to accept its `axis` argument as a keyword + or positional argument; + - `torch.clip` wrappers have been fixed to correctly handle `nan` scalars; + - `torch.repeat` wrapper has been fixed to not error out for short integers; + + +The following users contributed to this release: + +Evgeni Burovski, +Josh Soref. + + ## 1.13.0 (2025-12-28) From 145bd7d3f41d0cf98f0a7f35ba7e255874dec948 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Feb 2026 12:49:58 +0100 Subject: [PATCH 134/151] REL: bump the version number to 1.14.0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index a28101d6..785659a0 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.14.0.dev0' +__version__ = '1.14.0' from .common import * # noqa: F401, F403 From 551682145580039de2b8733a79b2658e06e8ad8a Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 26 Feb 2026 13:06:04 +0100 Subject: [PATCH 135/151] MAINT: bump the version to 1.15.0.dev0 --- array_api_compat/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index 785659a0..e7480cce 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -17,6 +17,6 @@ this implementation for the default when working with NumPy arrays. """ -__version__ = '1.14.0' +__version__ = '1.15.0.dev0' from .common import * # noqa: F401, F403 From 3e6e4d360c572e86a94302ed7d3d5f69aeefbe08 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:06:22 +0000 Subject: [PATCH 136/151] Bump the actions group with 3 updates Bumps the actions group with 3 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact), [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact) and [actions/download-artifact](https://github.com/actions/download-artifact). Updates `actions/upload-artifact` from 6 to 7 - [Release notes](https://github.com/actions/upload-artifact/releases) - [Commits](https://github.com/actions/upload-artifact/compare/v6...v7) Updates `dawidd6/action-download-artifact` from 15 to 16 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v15...v16) Updates `actions/download-artifact` from 7 to 8 - [Release notes](https://github.com/actions/download-artifact/releases) - [Commits](https://github.com/actions/download-artifact/compare/v7...v8) --- updated-dependencies: - dependency-name: actions/upload-artifact dependency-version: '7' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: dawidd6/action-download-artifact dependency-version: '16' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: actions/download-artifact dependency-version: '8' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-build.yml | 2 +- .github/workflows/docs-deploy.yml | 2 +- .github/workflows/publish-package.yml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 91a68457..b0c39822 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -16,7 +16,7 @@ jobs: cd docs make html - name: Upload Artifact - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: docs-build path: docs/_build/html diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index d237595a..b5d062cd 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v15 + uses: dawidd6/action-download-artifact@v16 with: workflow: docs-build.yml name: docs-build diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index ce47aed5..885e5d2c 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -60,7 +60,7 @@ jobs: run: python -m zipfile --list dist/array_api_compat-*.whl - name: Upload distribution artifact - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@v7 with: name: dist-artifact path: dist @@ -81,7 +81,7 @@ jobs: steps: - name: Download distribution artifact - uses: actions/download-artifact@v7 + uses: actions/download-artifact@v8 with: name: dist-artifact path: dist From e1fda8e65afb9fb48c1da79ff96d9d9a3aa2cbe7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 16 Mar 2026 09:50:40 +0000 Subject: [PATCH 137/151] Bump dawidd6/action-download-artifact from 16 to 18 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 16 to 18 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v16...v18) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '18' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index b5d062cd..ccee809c 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v16 + uses: dawidd6/action-download-artifact@v18 with: workflow: docs-build.yml name: docs-build From 4078863f0195959df736588bc2402bb034cce34c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Tue, 17 Mar 2026 10:26:28 +0100 Subject: [PATCH 138/151] BUG: torch: work around torch.round not supporting complex inputs --- array_api_compat/torch/_aliases.py | 18 +++++++++++++++++- tests/test_torch.py | 9 +++++++++ torch-xfails.txt | 1 - 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 69bd3763..a5348dc2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -912,6 +912,22 @@ def sign(x: Array, /) -> Array: return out +def round(x: Array, /, **kwargs) -> Array: + # torch.round fails for complex inputs + # https://github.com/pytorch/pytorch/issues/58743#issuecomment-2727603845 + if x.dtype.is_complex: + out = kwargs.pop('out', None) + res_r = torch.round(x.real, **kwargs) + res_i = torch.round(x.imag, **kwargs) + res = res_r + 1j*res_i + if out is not None: + out.copy_(res) + return out + return res + else: + return torch.round(x, **kwargs) + + def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Array, ...]: # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it # will be required to pass the indexing argument." @@ -923,7 +939,7 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', - 'diff', 'divide', + 'diff', 'divide', 'round', 'equal', 'floor_divide', 'greater', 'greater_equal', 'hypot', 'less', 'less_equal', 'logaddexp', 'maximum', 'minimum', 'multiply', 'not_equal', 'pow', 'remainder', 'subtract', 'max', diff --git a/tests/test_torch.py b/tests/test_torch.py index 463dd597..35ef5dda 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -152,3 +152,12 @@ def test_argsort_stable(): t = xp.zeros(50) # should be >16 assert xp.all(xp.argsort(t) == xp.arange(50)) + + +def test_round(): + """Verify the out= argument of xp.round with complex inputs.""" + x = torch.as_tensor([1.23456786]*3) + 3.456789j + o = torch.empty(3, dtype=torch.complex64) + r = xp.round(x, decimals=1, out=o) + assert xp.all(r == o) + assert r is o diff --git a/torch-xfails.txt b/torch-xfails.txt index 84271a56..3b75972b 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -130,7 +130,6 @@ array_api_tests/test_statistical_functions.py::test_var # These functions do not yet support complex numbers -array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts array_api_tests/test_set_functions.py::test_unique_values From ac7e9976329a1692f88d7c614557810405aec2e3 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Tue, 17 Mar 2026 11:17:09 +0000 Subject: [PATCH 139/151] BUG: torch.arange: workaround for missing dtype implementations (#405) * BUG: torch.arange: workaround for missing dtype implementations reviewed at https://github.com/data-apis/array-api-compat/pull/405 --- array_api_compat/torch/_aliases.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 69bd3763..61618f70 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -616,8 +616,12 @@ def arange(start: float, dtype = torch.int64 else: dtype = torch.float32 - return torch.empty(0, dtype=dtype, device=device, **kwargs) - return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + return torch.empty(0, device=device, **kwargs).to(dtype) + try: + return torch.arange(start, stop, step, dtype=dtype, device=device, **kwargs) + # torch 2.7 raises RuntimeError, 2.9 emits NotImplementedError + except (NotImplementedError, RuntimeError): + return torch.arange(start, stop, step, device=device, **kwargs).to(dtype) # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) From 5d9cc213bfdc6826bc2621383d807d9e0c03adc0 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 23 Mar 2026 09:49:49 +0000 Subject: [PATCH 140/151] Bump dawidd6/action-download-artifact from 18 to 19 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 18 to 19 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v18...v19) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '19' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index ccee809c..307c48fe 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v18 + uses: dawidd6/action-download-artifact@v19 with: workflow: docs-build.yml name: docs-build From f51d83ce6a03ee122868a27d6c7a07d96fb9bf55 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Fri, 3 Apr 2026 13:15:42 +0200 Subject: [PATCH 141/151] TST: skip dlpack tests on numpy 1.2x, dask, and cupy --- cupy-xfails.txt | 5 +++++ dask-xfails.txt | 5 +++++ numpy-1-22-xfails.txt | 5 +++++ numpy-1-26-xfails.txt | 5 ++++- 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/cupy-xfails.txt b/cupy-xfails.txt index 7e147635..a32e382c 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -217,3 +217,8 @@ array_api_tests/test_special_cases.py::test_binary[nextafter(x1_i is +0 and x2_i # cupy 13.x follows numpy 1.x w/o weak promotion: result_type(int32, uint8, 1) != result_type(int32, uint8) array_api_tests/test_data_type_functions.py::TestResultType::test_with_scalars + +# CuPy 14 does not support copy= dlpack argument +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack + diff --git a/dask-xfails.txt b/dask-xfails.txt index c2c54af2..34d3afd6 100644 --- a/dask-xfails.txt +++ b/dask-xfails.txt @@ -173,3 +173,8 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity an array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real(x_i) == -0) and imag(x_i) is +0) -> 0 + 0j] array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] + +# no dlpack support +array_api_tests/test_dlpack.py::test_dlpack_device +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack diff --git a/numpy-1-22-xfails.txt b/numpy-1-22-xfails.txt index c2e13d35..20477f99 100644 --- a/numpy-1-22-xfails.txt +++ b/numpy-1-22-xfails.txt @@ -191,3 +191,8 @@ array_api_tests/test_special_cases.py::test_unary[expm1((real(x_i) is +0 or real array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> +0 + infinity j] array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] + +# no dlpack support +array_api_tests/test_dlpack.py::test_dlpack_device +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack diff --git a/numpy-1-26-xfails.txt b/numpy-1-26-xfails.txt index 831de9e4..45d62f23 100644 --- a/numpy-1-26-xfails.txt +++ b/numpy-1-26-xfails.txt @@ -85,6 +85,9 @@ array_api_tests/test_special_cases.py::test_unary[sqrt(real(x_i) is +infinity an array_api_tests/test_special_cases.py::test_unary[acosh(real(x_i) is +0 and imag(x_i) is NaN) -> NaN \xb1 \u03c0j/2] array_api_tests/test_special_cases.py::test_unary[tanh(real(x_i) is +infinity and isfinite(imag(x_i)) and imag(x_i) > 0) -> 1 + 0j] - +# no dlpack support +array_api_tests/test_dlpack.py::test_dlpack_device +array_api_tests/test_dlpack.py::test_dunder_dlpack +array_api_tests/test_dlpack.py::test_from_dlpack From 75ee8bc2c5efee29301e38eba09707b507cd955d Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Fri, 3 Apr 2026 19:22:11 +0100 Subject: [PATCH 142/151] ENH: array_namespace: support `torch.compile` (#413) * ENH: array_namespace: support `torch.compile` without graph breaks torch.dynamo does not know that module objects are hashable, and graph breaks on compiling a set of module variables. Use lists instead to avoid a graph break. --------- Co-authored-by: Evgeni Burovski --- array_api_compat/common/_helpers.py | 16 ++++++++++++++-- tests/test_no_dependencies.py | 2 +- tests/test_torch.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 8a307f9d..10134367 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -641,7 +641,7 @@ def your_function(x, y): is_pydata_sparse_array """ - namespaces: set[Namespace] = set() + namespaces: list[Namespace] = [] for x in xs: xp, info = _cls_to_namespace(cast(Hashable, type(x)), api_version, use_compat) if info is _ClsToXPInfo.SCALAR: @@ -663,7 +663,19 @@ def your_function(x, y): ) xp = get_ns(api_version=api_version) - namespaces.add(xp) + namespaces.append(xp) + + # Use a list of modules to avoid a graph break under torch.compile: + # torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable + # Explanation: Dynamo does not know whether the underlying python object for + # PythonModuleVariable( Date: Sun, 5 Apr 2026 18:12:53 +0200 Subject: [PATCH 143/151] =?UTF-8?q?MAINT:=20=5F=5Fpackage=5F=5F=20?= =?UTF-8?q?=E2=86=92=20=5F=5Fspec=5F=5F.parent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove deprecated __package__, scheduled for removal in Python 3.15: https://docs.python.org/3.15/reference/datamodel.html#module.__package__ --- array_api_compat/cupy/__init__.py | 4 ++-- array_api_compat/dask/array/__init__.py | 4 ++-- array_api_compat/numpy/__init__.py | 4 ++-- array_api_compat/torch/__init__.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 558a83e1..246ac872 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -9,8 +9,8 @@ from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') __array_api_version__: Final = '2025.12' diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1905c671..d25ae513 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -13,8 +13,8 @@ del Final # See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') __all__ = sorted( set(__all__) diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 81eaafef..bda4356f 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -20,9 +20,9 @@ # # It doesn't overwrite np.linalg from above. The import is generated # dynamically so that the library can be vendored. -__import__(__package__ + ".linalg") +__import__(__spec__.parent + ".linalg") -__import__(__package__ + ".fft") +__import__(__spec__.parent + ".fft") from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 8263faa6..c5c801aa 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -10,8 +10,8 @@ from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py -__import__(__package__ + '.linalg') -__import__(__package__ + '.fft') +__import__(__spec__.parent + '.linalg') +__import__(__spec__.parent + '.fft') __array_api_version__: Final = '2025.12' From 6e1bd6d8f11892edeb172a07678c3c4986e45e6b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Apr 2026 09:47:07 +0000 Subject: [PATCH 144/151] Bump dawidd6/action-download-artifact from 19 to 20 in the actions group Bumps the actions group with 1 update: [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact). Updates `dawidd6/action-download-artifact` from 19 to 20 - [Release notes](https://github.com/dawidd6/action-download-artifact/releases) - [Commits](https://github.com/dawidd6/action-download-artifact/compare/v19...v20) --- updated-dependencies: - dependency-name: dawidd6/action-download-artifact dependency-version: '20' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/docs-deploy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml index 307c48fe..f96b39ec 100644 --- a/.github/workflows/docs-deploy.yml +++ b/.github/workflows/docs-deploy.yml @@ -13,7 +13,7 @@ jobs: steps: - uses: actions/checkout@v6 - name: Download Artifact - uses: dawidd6/action-download-artifact@v19 + uses: dawidd6/action-download-artifact@v20 with: workflow: docs-build.yml name: docs-build From e5c981bd2a76858a53f1966b5ae2f5107265732a Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Wed, 8 Apr 2026 11:18:05 +0300 Subject: [PATCH 145/151] Fix typos (#417) --- array_api_compat/common/_helpers.py | 2 +- docs/changelog.md | 2 +- torch-xfails.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 10134367..b43e3d22 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1054,7 +1054,7 @@ def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: try: bool(x) return False - # The Array API standard dictactes that __bool__ should raise TypeError if the + # The Array API standard dictates that __bool__ should raise TypeError if the # output cannot be defined. # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does. except Exception: diff --git a/docs/changelog.md b/docs/changelog.md index 527dd3bd..55449df0 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -257,7 +257,7 @@ Thomas Li than setting the conjugation bit, as arrays with the conjugation bit set do not support some APIs. -- `torch.sign` is now wrapped to support complex numbers and propogate nans +- `torch.sign` is now wrapped to support complex numbers and propagate nans properly. ### Minor Changes diff --git a/torch-xfails.txt b/torch-xfails.txt index 3b75972b..22aeed8f 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -6,7 +6,7 @@ # Indexing does not support negative step array_api_tests/test_array_object.py::test_getitem array_api_tests/test_array_object.py::test_setitem -# Masking doesn't suport 0 dimensions in the mask +# Masking doesn't support 0 dimensions in the mask array_api_tests/test_array_object.py::test_getitem_masking # Overflow error from large inputs From dce7b884f0c406c3962a330d99b62931900fc873 Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Wed, 8 Apr 2026 11:19:40 +0300 Subject: [PATCH 146/151] MAINT: document Python 3.14 support in trove classifiers (#418) --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ec054417..073074ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Software Development :: Libraries :: Python Modules", "Typing :: Typed", ] From a0efa5ccb62e19c7f34d3a93c3d2bf091274405f Mon Sep 17 00:00:00 2001 From: Dimitri Papadopoulos Orfanos <3234522+DimitriPapadopoulos@users.noreply.github.com> Date: Sun, 12 Apr 2026 18:22:33 +0300 Subject: [PATCH 147/151] Merge pull request #416 from DimitriPapadopoulos/ruff MAINT: Enforce additional ruff rules --- array_api_compat/__init__.py | 2 +- array_api_compat/common/_helpers.py | 2 +- array_api_compat/cupy/linalg.py | 2 +- array_api_compat/numpy/__init__.py | 1 - array_api_compat/torch/_aliases.py | 2 +- array_api_compat/torch/fft.py | 20 ++++++++++---------- pyproject.toml | 24 +++++++++++++++++------- tests/test_isdtype.py | 2 +- tests/test_no_dependencies.py | 4 ++-- 9 files changed, 34 insertions(+), 25 deletions(-) diff --git a/array_api_compat/__init__.py b/array_api_compat/__init__.py index e7480cce..454c7f8b 100644 --- a/array_api_compat/__init__.py +++ b/array_api_compat/__init__.py @@ -19,4 +19,4 @@ """ __version__ = '1.15.0.dev0' -from .common import * # noqa: F401, F403 +from .common import * # noqa: F403 diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index b43e3d22..c154ad0b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -669,7 +669,7 @@ def your_function(x, y): # torch._dynamo.exc.Unsupported: Dynamo cannot determine whether the underlying object is hashable # Explanation: Dynamo does not know whether the underlying python object for # PythonModuleVariable( list[str]: return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index bda4356f..973e993d 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,4 +1,3 @@ -# ruff: noqa: PLC0414 from typing import Final from .._internal import clone_module diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 88936302..5969e4db 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -278,7 +278,7 @@ def _axis_none_keepdims(x, ndim, keepdims): # (https://github.com/pytorch/pytorch/issues/71209) # Note that this is only valid for the axis=None case. if keepdims: - for i in range(ndim): + for _ in range(ndim): x = torch.unsqueeze(x, 0) return x diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index f11b3eb5..0fa6ea9a 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -17,8 +17,8 @@ def fftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs: object, ) -> Array: @@ -28,8 +28,8 @@ def ifftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs: object, ) -> Array: @@ -39,8 +39,8 @@ def rfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs: object, ) -> Array: @@ -50,8 +50,8 @@ def irfftn( x: Array, /, *, - s: Sequence[int] = None, - axes: Sequence[int] = None, + s: Sequence[int] | None = None, + axes: Sequence[int] | None = None, norm: Literal["backward", "ortho", "forward"] = "backward", **kwargs: object, ) -> Array: @@ -61,7 +61,7 @@ def fftshift( x: Array, /, *, - axes: int | Sequence[int] = None, + axes: int | Sequence[int] | None = None, **kwargs: object, ) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) @@ -70,7 +70,7 @@ def ifftshift( x: Array, /, *, - axes: int | Sequence[int] = None, + axes: int | Sequence[int] | None = None, **kwargs: object, ) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 073074ca..d7339170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,19 +65,29 @@ namespaces = false [tool.ruff.lint] preview = true select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" + # Defaults + "E4", "E7", "E9", "F", + # Additional rules + "B", "C4", "ISC", "PIE", "FLY", "PERF", "UP", "FURB", + # Useless import alias + "PLC0414", + # Unused `noqa` directive + "RUF100", ] ignore = [ # Module import not at top of file "E402", # Do not use bare `except` - "E722" + "E722", + # Use of `functools.cache` on methods can lead to memory leaks + "B019", + # No explicit `stacklevel` keyword argument found + "B028", + # Within an `except` clause, raise exceptions with `raise ... from ...` + "B904", + # `try`-`except` within a loop incurs performance overhead + "PERF203", ] diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 6ad45d4c..92e45613 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -61,7 +61,7 @@ def isdtype_(dtype_, kind): res = dtype_categories[kind](dtype_) else: res = dtype_ == kind - assert type(res) is bool # noqa: E721 + assert type(res) is bool return res @pytest.mark.parametrize("library", wrapped_libraries) diff --git a/tests/test_no_dependencies.py b/tests/test_no_dependencies.py index c53780b2..624f8971 100644 --- a/tests/test_no_dependencies.py +++ b/tests/test_no_dependencies.py @@ -38,11 +38,11 @@ def _test_dependency(mod): assert not is_mod_array(a) assert mod not in sys.modules - is_array_api_obj = getattr(array_api_compat, "is_array_api_obj") + is_array_api_obj = array_api_compat.is_array_api_obj assert is_array_api_obj(a) assert mod not in sys.modules - array_namespace = getattr(array_api_compat, "array_namespace") + array_namespace = array_api_compat.array_namespace array_namespace(Array()) assert mod not in sys.modules From c0a506137a446e9241d52a2f96fc34acf15280b7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 Apr 2026 10:07:49 +0000 Subject: [PATCH 148/151] Bump the actions group with 3 updates Bumps the actions group with 3 updates: [dependabot/fetch-metadata](https://github.com/dependabot/fetch-metadata), [pypa/gh-action-pypi-publish](https://github.com/pypa/gh-action-pypi-publish) and [softprops/action-gh-release](https://github.com/softprops/action-gh-release). Updates `dependabot/fetch-metadata` from 2 to 3 - [Release notes](https://github.com/dependabot/fetch-metadata/releases) - [Commits](https://github.com/dependabot/fetch-metadata/compare/v2...v3) Updates `pypa/gh-action-pypi-publish` from 1.13.0 to 1.14.0 - [Release notes](https://github.com/pypa/gh-action-pypi-publish/releases) - [Commits](https://github.com/pypa/gh-action-pypi-publish/compare/v1.13.0...v1.14.0) Updates `softprops/action-gh-release` from 2 to 3 - [Release notes](https://github.com/softprops/action-gh-release/releases) - [Changelog](https://github.com/softprops/action-gh-release/blob/master/CHANGELOG.md) - [Commits](https://github.com/softprops/action-gh-release/compare/v2...v3) --- updated-dependencies: - dependency-name: dependabot/fetch-metadata dependency-version: '3' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions - dependency-name: pypa/gh-action-pypi-publish dependency-version: 1.14.0 dependency-type: direct:production update-type: version-update:semver-minor dependency-group: actions - dependency-name: softprops/action-gh-release dependency-version: '3' dependency-type: direct:production update-type: version-update:semver-major dependency-group: actions ... Signed-off-by: dependabot[bot] --- .github/workflows/dependabot-auto-merge.yml | 2 +- .github/workflows/publish-package.yml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/dependabot-auto-merge.yml b/.github/workflows/dependabot-auto-merge.yml index bd29e25b..48dfc680 100644 --- a/.github/workflows/dependabot-auto-merge.yml +++ b/.github/workflows/dependabot-auto-merge.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Dependabot metadata id: metadata - uses: dependabot/fetch-metadata@v2 + uses: dependabot/fetch-metadata@v3 with: github-token: "${{ secrets.GITHUB_TOKEN }}" - name: Enable auto-merge for Dependabot PRs diff --git a/.github/workflows/publish-package.yml b/.github/workflows/publish-package.yml index 885e5d2c..826c5239 100644 --- a/.github/workflows/publish-package.yml +++ b/.github/workflows/publish-package.yml @@ -95,19 +95,19 @@ jobs: # if: >- # (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) # || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish == 'true') - # uses: pypa/gh-action-pypi-publish@v1.13.0 + # uses: pypa/gh-action-pypi-publish@v1.14.0 # with: # repository-url: https://test.pypi.org/legacy/ # print-hash: true - name: Publish distribution 📦 to PyPI if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@v1.13.0 + uses: pypa/gh-action-pypi-publish@v1.14.0 with: print-hash: true - name: Create GitHub Release from a Tag - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 if: startsWith(github.ref, 'refs/tags/') with: files: dist/* From 3c26d639253da489fb1ca586ea44f6ad525dd4a1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 23 Apr 2026 09:03:15 +0200 Subject: [PATCH 149/151] CI: add dependabot cooldown period Add a 7-day cooldown period for Dependabot updates --- .github/dependabot.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 5b5616f1..2fcdc68d 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,6 +5,8 @@ updates: directory: "/" schedule: interval: "weekly" + cooldown: + default-days: 7 # optional groups: actions: patterns: From 3bfeb43773e224ed7eecab5203d7eeb9bad9fe9e Mon Sep 17 00:00:00 2001 From: Chris Ninham <61634310+Nin17@users.noreply.github.com> Date: Tue, 28 Apr 2026 21:11:42 +0200 Subject: [PATCH 150/151] BUG: torch.meshgrid RuntimeError if no arrays (#425) * BUG: torch RuntimeError if no arrays * TST: torch.meshgrid no arrays * BUG: torch.meshgrid check indexing first - consistency with numpy+jax.numpy --- array_api_compat/torch/_aliases.py | 4 +++- tests/test_torch.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 5969e4db..e27c3ca2 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -936,7 +936,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra # torch <= 2.9 emits a UserWarning: "torch.meshgrid: in an upcoming release, it # will be required to pass the indexing argument." # Thus always pass it explicitly. - return torch.meshgrid(*arrays, indexing=indexing) + if indexing not in ("xy", "ij"): + raise ValueError(f'torch.meshgrid: indexing must be one of "xy" or "ij", but received: {indexing}') + return torch.meshgrid(*arrays, indexing=indexing) if arrays else () __all__ = ['asarray', 'result_type', 'can_cast', diff --git a/tests/test_torch.py b/tests/test_torch.py index b064a46d..3d6ebc46 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -115,7 +115,7 @@ def apply_clip_compat(a): def test_meshgrid(): - """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy'.""" + """Verify that array_api_compat.torch.meshgrid defaults to indexing='xy', and supports passing no arrays.""" x, y = xp.asarray([1, 2]), xp.asarray([4]) @@ -142,6 +142,8 @@ def test_meshgrid(): assert Y.shape == Y_ij.shape assert xp.all(Y == Y_ij) + assert not xp.meshgrid() + def test_argsort_stable(): """Verify that argsort defaults to a stable sort.""" From 907116c0261bf2f0498a7acdc05533031a81a108 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Wed, 29 Apr 2026 08:45:16 +0200 Subject: [PATCH 151/151] BUG: np.matrix is not an array API object (#423) Being an ndarray subclass, np.matrix inherits the `__array_namespace__` method, even though it is not Array API compatible, and is not meant to be. ``` (Pdb) x matrix([[3]]) (Pdb) p x.__array_namespace__() is np True ``` --- array_api_compat/common/_helpers.py | 8 +++++++- tests/test_numpy.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) create mode 100644 tests/test_numpy.py diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index c154ad0b..d5342658 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -288,8 +288,14 @@ def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: is_dask_array is_jax_array """ + try: + # TODO: drop this check after np.matrix is gone + if _issubclass_fast(type(x), "numpy", "matrix"): + return False + except Exception: + pass return ( - hasattr(x, '__array_namespace__') + hasattr(x, '__array_namespace__') or _is_array_api_cls(cast(Hashable, type(x))) ) diff --git a/tests/test_numpy.py b/tests/test_numpy.py new file mode 100644 index 00000000..a139d428 --- /dev/null +++ b/tests/test_numpy.py @@ -0,0 +1,19 @@ +"""Test "unspecified" behavior which we cannot easily test in the Array API test suite. +""" +import warnings +import pytest + +try: + import numpy as np +except ImportError: + pytestmark = pytest.skip(allow_module_level=True, reason="numpy not found") + +from array_api_compat import is_array_api_obj + +def test_matrix_is_not_array_api_obj(): + assert is_array_api_obj(np.asarray(3)) + assert is_array_api_obj(np.float64(3)) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", PendingDeprecationWarning) + assert not is_array_api_obj(np.matrix(3))