diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index 075b2f4a0c98..7fe5172e4bb9 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -739,7 +739,10 @@ type _Truthy = L[True, 1] | bool_[L[True]] type _1D = tuple[int] type _2D = tuple[int, int] +type _3D = tuple[int, int, int] + type _2Tuple[T] = tuple[T, T] +type _3Tuple[T] = tuple[T, T, T] type _ArrayUInt_co = NDArray[unsignedinteger | bool_] type _ArrayInt_co = NDArray[integer | bool_] @@ -2414,8 +2417,17 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]): @overload def dot[ArrayT: ndarray](self, b: ArrayLike, /, out: ArrayT) -> ArrayT: ... - # `nonzero()` raises for 0d arrays/generics - def nonzero(self) -> tuple[ndarray[tuple[int], _dtype[intp]], ...]: ... + # keep in sync with `_core.fromnumeric.nonzero` + @overload # ?d (workaround) + def nonzero(self: ndarray[tuple[Never, Never, Never, Never]]) -> tuple[ndarray[_1D, _dtype[intp]], ...]: ... + @overload # 1d + def nonzero(self: ndarray[_1D]) -> tuple[ndarray[_1D, _dtype[intp]]]: ... + @overload # 2d + def nonzero(self: ndarray[_2D]) -> _2Tuple[ndarray[_1D, _dtype[intp]]]: ... + @overload # 3d + def nonzero(self: ndarray[_3D]) -> _3Tuple[ndarray[_1D, _dtype[intp]]]: ... + @overload # 3d + def nonzero(self) -> tuple[ndarray[_1D, _dtype[intp]], ...]: ... @overload def searchsorted( diff --git a/numpy/_core/fromnumeric.pyi b/numpy/_core/fromnumeric.pyi index 6a42651fa87d..23d61b97500e 100644 --- a/numpy/_core/fromnumeric.pyi +++ b/numpy/_core/fromnumeric.pyi @@ -129,6 +129,14 @@ type _3D = tuple[int, int, int] type _4D = tuple[int, int, int, int] type _Array1D[ScalarT: np.generic] = np.ndarray[_1D, np.dtype[ScalarT]] +type _Array2D[ScalarT: np.generic] = np.ndarray[_2D, np.dtype[ScalarT]] +type _Array3D[ScalarT: np.generic] = np.ndarray[_3D, np.dtype[ScalarT]] +# workaround for mypy's and pyright's typing spec non-compliance regarding overloads +type _ArrayJustND[ScalarT: np.generic] = np.ndarray[tuple[Never, Never, Never, Never], np.dtype[ScalarT]] + +type _ToArray1D[ScalarT: np.generic] = _Array1D[ScalarT] | Sequence[ScalarT] +type _ToArray2D[ScalarT: np.generic] = _Array2D[ScalarT] | Sequence[Sequence[ScalarT]] +type _ToArray3D[ScalarT: np.generic] = _Array3D[ScalarT] | Sequence[Sequence[Sequence[ScalarT]]] ### @@ -658,7 +666,17 @@ def ravel(a: complex | _NestedSequence[complex], order: _OrderKACF = "C") -> _Ar @overload def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[_1D]: ... -def nonzero(a: _ArrayLike[Any]) -> tuple[_Array1D[np.intp], ...]: ... +# keep in sync with the 1-arg overloads of `_core.multiarray.where` +@overload # ?d (workaround) +def nonzero(a: _ArrayJustND[Any]) -> tuple[_Array1D[np.intp], ...]: ... +@overload # 1d +def nonzero(a: _ToArray1D[Any]) -> tuple[_Array1D[np.intp]]: ... +@overload # 2d +def nonzero(a: _ToArray2D[Any]) -> tuple[_Array1D[np.intp], _Array1D[np.intp]]: ... +@overload # 3d +def nonzero(a: _ToArray3D[Any]) -> tuple[_Array1D[np.intp], _Array1D[np.intp], _Array1D[np.intp]]: ... +@overload # Nd (fallback) +def nonzero(a: _ArrayLike[Any]) -> tuple[_Array1D[np.intp], ...]: ... # this prevents `Any` from being returned with Pyright @overload diff --git a/numpy/_core/multiarray.pyi b/numpy/_core/multiarray.pyi index d0844a245590..226859eef5fe 100644 --- a/numpy/_core/multiarray.pyi +++ b/numpy/_core/multiarray.pyi @@ -185,11 +185,13 @@ _ArrayT_co = TypeVar("_ArrayT_co", bound=np.ndarray, default=np.ndarray, covaria type _Array[ShapeT: _Shape, ScalarT: np.generic] = ndarray[ShapeT, dtype[ScalarT]] type _Array1D[ScalarT: np.generic] = ndarray[tuple[int], dtype[ScalarT]] type _Array2D[ScalarT: np.generic] = ndarray[tuple[int, int], dtype[ScalarT]] +type _Array3D[ScalarT: np.generic] = ndarray[tuple[int, int, int], dtype[ScalarT]] # workaround for mypy's and pyright's typing spec non-compliance regarding overloads -type _ArrayJustND[ScalarT: np.generic] = ndarray[tuple[Never, Never, Never], dtype[ScalarT]] +type _ArrayJustND[ScalarT: np.generic] = ndarray[tuple[Never, Never, Never, Never], dtype[ScalarT]] type _ToArray1D[ScalarT: np.generic] = _Array1D[ScalarT] | Sequence[ScalarT] type _ToArray2D[ScalarT: np.generic] = _Array2D[ScalarT] | Sequence[Sequence[ScalarT]] +type _ToArray3D[ScalarT: np.generic] = _Array3D[ScalarT] | Sequence[Sequence[Sequence[ScalarT]]] # Valid time units type _UnitKind = L[ @@ -763,12 +765,23 @@ def dot(a: ArrayLike, b: ArrayLike, out: None = None) -> Incomplete: ... @overload def dot[OutT: np.ndarray](a: ArrayLike, b: ArrayLike, out: OutT) -> OutT: ... -# keep in sync with `ma.core.where` -@overload -def where(condition: ArrayLike, x: None = None, y: None = None, /) -> tuple[NDArray[intp], ...]: ... +# keep in sync with `ma.core.where` and the 1-arg overloads with `_core.fromnumeric.nonzerp` +@overload # (?d) (workaround) +def where(condition: _ArrayJustND[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], ...]: ... +@overload # (1d) +def where(condition: _ToArray1D[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp]]: ... +@overload # (2d) +def where(condition: _ToArray2D[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], _Array1D[np.intp]]: ... +@overload # (3d) +def where( + condition: _ToArray3D[Any], x: None = None, y: None = None, / +) -> tuple[_Array1D[np.intp], _Array1D[np.intp], _Array1D[np.intp]]: ... +@overload # (Nd) (fallback) +def where(condition: _ArrayLike[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], ...]: ... @overload def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /) -> NDArray[Incomplete]: ... +# def lexsort(keys: ArrayLike, axis: SupportsIndex = -1) -> NDArray[intp]: ... def can_cast(from_: ArrayLike | DTypeLike, to: DTypeLike, casting: _CastingKind = "safe") -> bool: ... diff --git a/numpy/typing/tests/data/fail/fromnumeric.pyi b/numpy/typing/tests/data/fail/fromnumeric.pyi index 7f98ab4602c2..5655ede23e6e 100644 --- a/numpy/typing/tests/data/fail/fromnumeric.pyi +++ b/numpy/typing/tests/data/fail/fromnumeric.pyi @@ -75,7 +75,7 @@ np.trace(A, axis2=[]) # type: ignore[call-overload] np.ravel(a, order="bob") # type: ignore[call-overload] -np.nonzero(0) # type: ignore[arg-type] +np.nonzero(0) # type: ignore[call-overload] np.compress([True], A, axis=1.0) # type: ignore[call-overload] diff --git a/numpy/typing/tests/data/reveal/fromnumeric.pyi b/numpy/typing/tests/data/reveal/fromnumeric.pyi index 042690fb4367..2048ace97d46 100644 --- a/numpy/typing/tests/data/reveal/fromnumeric.pyi +++ b/numpy/typing/tests/data/reveal/fromnumeric.pyi @@ -11,6 +11,7 @@ AR_b: npt.NDArray[np.bool] AR_f4: npt.NDArray[np.float32] AR_f4_1d: np.ndarray[tuple[int], np.dtype[np.float32]] AR_f4_2d: np.ndarray[tuple[int, int], np.dtype[np.float32]] +AR_f4_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float32]] AR_c16: npt.NDArray[np.complex128] AR_u8: npt.NDArray[np.uint64] AR_i8: npt.NDArray[np.int64] @@ -155,10 +156,12 @@ assert_type(np.ravel(f), np.ndarray[tuple[int], np.dtype[np.float64 | Any]]) assert_type(np.ravel(AR_b), np.ndarray[tuple[int], np.dtype[np.bool]]) assert_type(np.ravel(AR_f4), np.ndarray[tuple[int], np.dtype[np.float32]]) -assert_type(np.nonzero(AR_b), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...]) -assert_type(np.nonzero(AR_f4), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...]) -assert_type(np.nonzero(AR_1d), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...]) -assert_type(np.nonzero(AR_nd), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...]) +type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]] + +assert_type(np.nonzero(AR_f4), tuple[_Int1D, ...]) +assert_type(np.nonzero(AR_f4_1d), tuple[_Int1D]) +assert_type(np.nonzero(AR_f4_2d), tuple[_Int1D, _Int1D]) +assert_type(np.nonzero(AR_f4_3d), tuple[_Int1D, _Int1D, _Int1D]) assert_type(np.shape(b), tuple[()]) assert_type(np.shape(f), tuple[()]) diff --git a/numpy/typing/tests/data/reveal/ma.pyi b/numpy/typing/tests/data/reveal/ma.pyi index 944983b97fa1..18bd3acd916d 100644 --- a/numpy/typing/tests/data/reveal/ma.pyi +++ b/numpy/typing/tests/data/reveal/ma.pyi @@ -424,8 +424,7 @@ assert_type(MAR_2d_f4.dot(1), MaskedArray[Any]) assert_type(MAR_2d_f4.dot([1]), MaskedArray[Any]) assert_type(MAR_2d_f4.dot(1, out=MAR_subclass), MaskedArraySubclassC) -assert_type(MAR_2d_f4.nonzero(), tuple[_Array1D[np.intp], ...]) -assert_type(MAR_2d_f4.nonzero()[0], _Array1D[np.intp]) +assert_type(MAR_2d_f4.nonzero(), tuple[_Array1D[np.intp], _Array1D[np.intp]]) assert_type(MAR_f8.trace(), Any) assert_type(MAR_f8.trace(out=MAR_subclass), MaskedArraySubclassC) diff --git a/numpy/typing/tests/data/reveal/multiarray.pyi b/numpy/typing/tests/data/reveal/multiarray.pyi index f3d7c0749fc7..2c8bab482ff1 100644 --- a/numpy/typing/tests/data/reveal/multiarray.pyi +++ b/numpy/typing/tests/data/reveal/multiarray.pyi @@ -11,6 +11,7 @@ subclass: SubClass[np.float64] AR_f4_nd: npt.NDArray[np.float32] AR_f4_1d: np.ndarray[tuple[int], np.dtype[np.float32]] AR_f4_2d: np.ndarray[tuple[int, int], np.dtype[np.float32]] +AR_f4_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float32]] AR_f8: npt.NDArray[np.float64] AR_c16: npt.NDArray[np.complex128] AR_i8: npt.NDArray[np.int64] @@ -153,7 +154,52 @@ assert_type(np.dot(AR_O_nd, AR_O_nd), Any) # -assert_type(np.where([True, True, False]), tuple[npt.NDArray[np.intp], ...]) +assert_type(np.dot(AR_LIKE_b, AR_LIKE_b), np.bool) +assert_type(np.dot(AR_LIKE_b, AR_LIKE_i), np.int_) +assert_type(np.dot(AR_LIKE_b, AR_LIKE_f), np.float64) +assert_type(np.dot(AR_LIKE_b, AR_LIKE_c), np.complex128) +assert_type(np.dot(AR_LIKE_i, AR_LIKE_b), np.int_) +assert_type(np.dot(AR_LIKE_i, AR_LIKE_i), np.int_) +assert_type(np.dot(AR_LIKE_i, AR_LIKE_f), np.float64) +assert_type(np.dot(AR_LIKE_i, AR_LIKE_c), np.complex128) +assert_type(np.dot(AR_LIKE_f, AR_LIKE_b), np.float64) +assert_type(np.dot(AR_LIKE_f, AR_LIKE_i), np.float64) +assert_type(np.dot(AR_LIKE_f, AR_LIKE_f), np.float64) +assert_type(np.dot(AR_LIKE_f, AR_LIKE_c), np.complex128) +assert_type(np.dot(AR_LIKE_c, AR_LIKE_b), np.complex128) +assert_type(np.dot(AR_LIKE_c, AR_LIKE_i), np.complex128) +assert_type(np.dot(AR_LIKE_c, AR_LIKE_f), np.complex128) +assert_type(np.dot(AR_LIKE_c, AR_LIKE_c), np.complex128) + +assert_type(np.dot(AR_f4_1d, AR_f4_1d), np.float32) +assert_type(np.dot(AR_f4_1d, AR_f4_2d), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.dot(AR_f4_1d, AR_f4_nd), Any) +assert_type(np.dot(AR_f4_2d, AR_f4_1d), np.ndarray[tuple[int], np.dtype[np.float32]]) +assert_type(np.dot(AR_f4_2d, AR_f4_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]]) +assert_type(np.dot(AR_f4_2d, AR_f4_nd), Any) +assert_type(np.dot(AR_f4_nd, AR_f4_1d), Any) +assert_type(np.dot(AR_f4_nd, AR_f4_2d), Any) +assert_type(np.dot(AR_f4_nd, AR_f4_nd), Any) + +assert_type(np.dot(AR_O_1d, AR_O_1d), Any) +assert_type(np.dot(AR_O_1d, AR_O_2d), np.ndarray[tuple[int], np.dtype[np.object_]]) +assert_type(np.dot(AR_O_1d, AR_O_nd), Any) +assert_type(np.dot(AR_O_2d, AR_O_1d), np.ndarray[tuple[int], np.dtype[np.object_]]) +assert_type(np.dot(AR_O_2d, AR_O_2d), np.ndarray[tuple[int, int], np.dtype[np.object_]]) +assert_type(np.dot(AR_O_2d, AR_O_nd), Any) +assert_type(np.dot(AR_O_nd, AR_O_1d), Any) +assert_type(np.dot(AR_O_nd, AR_O_2d), Any) +assert_type(np.dot(AR_O_nd, AR_O_nd), Any) + +# + +type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]] + +assert_type(np.where([True, True, False]), tuple[_Int1D,]) +assert_type(np.where(AR_f4_1d), tuple[_Int1D]) +assert_type(np.where(AR_f4_2d), tuple[_Int1D, _Int1D]) +assert_type(np.where(AR_f4_3d), tuple[_Int1D, _Int1D, _Int1D]) +assert_type(np.where(AR_f4_nd), tuple[_Int1D, ...]) assert_type(np.where([True, True, False], 1, 0), npt.NDArray[Any]) assert_type(np.lexsort([0, 1, 2]), npt.NDArray[np.intp]) diff --git a/numpy/typing/tests/data/reveal/ndarray_misc.pyi b/numpy/typing/tests/data/reveal/ndarray_misc.pyi index fa2c6020919f..b58472f08d49 100644 --- a/numpy/typing/tests/data/reveal/ndarray_misc.pyi +++ b/numpy/typing/tests/data/reveal/ndarray_misc.pyi @@ -171,7 +171,12 @@ assert_type(AR_f8.dot(1), npt.NDArray[Any]) assert_type(AR_f8.dot([1]), Any) assert_type(AR_f8.dot(1, out=B), SubClass) -assert_type(AR_f8.nonzero(), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...]) +type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]] + +assert_type(AR_f8.nonzero(), tuple[_Int1D, ...]) +assert_type(AR_f8_1d.nonzero(), tuple[_Int1D]) +assert_type(AR_f8_2d.nonzero(), tuple[_Int1D, _Int1D]) +assert_type(AR_f8_3d.nonzero(), tuple[_Int1D, _Int1D, _Int1D]) assert_type(AR_f8.searchsorted(1), np.intp) assert_type(AR_f8.searchsorted([1]), npt.NDArray[np.intp])