Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,10 @@ type _Truthy = L[True, 1] | bool_[L[True]]

type _1D = tuple[int]
type _2D = tuple[int, int]
type _3D = tuple[int, int, int]

type _2Tuple[T] = tuple[T, T]
type _3Tuple[T] = tuple[T, T, T]

type _ArrayUInt_co = NDArray[unsignedinteger | bool_]
type _ArrayInt_co = NDArray[integer | bool_]
Expand Down Expand Up @@ -2414,8 +2417,17 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
@overload
def dot[ArrayT: ndarray](self, b: ArrayLike, /, out: ArrayT) -> ArrayT: ...

# `nonzero()` raises for 0d arrays/generics
def nonzero(self) -> tuple[ndarray[tuple[int], _dtype[intp]], ...]: ...
# keep in sync with `_core.fromnumeric.nonzero`
@overload # ?d (workaround)
def nonzero(self: ndarray[tuple[Never, Never, Never, Never]]) -> tuple[ndarray[_1D, _dtype[intp]], ...]: ...
@overload # 1d
def nonzero(self: ndarray[_1D]) -> tuple[ndarray[_1D, _dtype[intp]]]: ...
@overload # 2d
def nonzero(self: ndarray[_2D]) -> _2Tuple[ndarray[_1D, _dtype[intp]]]: ...
@overload # 3d
def nonzero(self: ndarray[_3D]) -> _3Tuple[ndarray[_1D, _dtype[intp]]]: ...
@overload # 3d
def nonzero(self) -> tuple[ndarray[_1D, _dtype[intp]], ...]: ...

@overload
def searchsorted(
Expand Down
20 changes: 19 additions & 1 deletion numpy/_core/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,14 @@ type _3D = tuple[int, int, int]
type _4D = tuple[int, int, int, int]

type _Array1D[ScalarT: np.generic] = np.ndarray[_1D, np.dtype[ScalarT]]
type _Array2D[ScalarT: np.generic] = np.ndarray[_2D, np.dtype[ScalarT]]
type _Array3D[ScalarT: np.generic] = np.ndarray[_3D, np.dtype[ScalarT]]
# workaround for mypy's and pyright's typing spec non-compliance regarding overloads
type _ArrayJustND[ScalarT: np.generic] = np.ndarray[tuple[Never, Never, Never, Never], np.dtype[ScalarT]]

type _ToArray1D[ScalarT: np.generic] = _Array1D[ScalarT] | Sequence[ScalarT]
type _ToArray2D[ScalarT: np.generic] = _Array2D[ScalarT] | Sequence[Sequence[ScalarT]]
type _ToArray3D[ScalarT: np.generic] = _Array3D[ScalarT] | Sequence[Sequence[Sequence[ScalarT]]]

###

Expand Down Expand Up @@ -658,7 +666,17 @@ def ravel(a: complex | _NestedSequence[complex], order: _OrderKACF = "C") -> _Ar
@overload
def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[_1D]: ...

def nonzero(a: _ArrayLike[Any]) -> tuple[_Array1D[np.intp], ...]: ...
# keep in sync with the 1-arg overloads of `_core.multiarray.where`
@overload # ?d (workaround)
def nonzero(a: _ArrayJustND[Any]) -> tuple[_Array1D[np.intp], ...]: ...
@overload # 1d
def nonzero(a: _ToArray1D[Any]) -> tuple[_Array1D[np.intp]]: ...
@overload # 2d
def nonzero(a: _ToArray2D[Any]) -> tuple[_Array1D[np.intp], _Array1D[np.intp]]: ...
@overload # 3d
def nonzero(a: _ToArray3D[Any]) -> tuple[_Array1D[np.intp], _Array1D[np.intp], _Array1D[np.intp]]: ...
@overload # Nd (fallback)
def nonzero(a: _ArrayLike[Any]) -> tuple[_Array1D[np.intp], ...]: ...

# this prevents `Any` from being returned with Pyright
@overload
Expand Down
21 changes: 17 additions & 4 deletions numpy/_core/multiarray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,13 @@ _ArrayT_co = TypeVar("_ArrayT_co", bound=np.ndarray, default=np.ndarray, covaria
type _Array[ShapeT: _Shape, ScalarT: np.generic] = ndarray[ShapeT, dtype[ScalarT]]
type _Array1D[ScalarT: np.generic] = ndarray[tuple[int], dtype[ScalarT]]
type _Array2D[ScalarT: np.generic] = ndarray[tuple[int, int], dtype[ScalarT]]
type _Array3D[ScalarT: np.generic] = ndarray[tuple[int, int, int], dtype[ScalarT]]
# workaround for mypy's and pyright's typing spec non-compliance regarding overloads
type _ArrayJustND[ScalarT: np.generic] = ndarray[tuple[Never, Never, Never], dtype[ScalarT]]
type _ArrayJustND[ScalarT: np.generic] = ndarray[tuple[Never, Never, Never, Never], dtype[ScalarT]]

type _ToArray1D[ScalarT: np.generic] = _Array1D[ScalarT] | Sequence[ScalarT]
type _ToArray2D[ScalarT: np.generic] = _Array2D[ScalarT] | Sequence[Sequence[ScalarT]]
type _ToArray3D[ScalarT: np.generic] = _Array3D[ScalarT] | Sequence[Sequence[Sequence[ScalarT]]]

# Valid time units
type _UnitKind = L[
Expand Down Expand Up @@ -763,12 +765,23 @@ def dot(a: ArrayLike, b: ArrayLike, out: None = None) -> Incomplete: ...
@overload
def dot[OutT: np.ndarray](a: ArrayLike, b: ArrayLike, out: OutT) -> OutT: ...

# keep in sync with `ma.core.where`
@overload
def where(condition: ArrayLike, x: None = None, y: None = None, /) -> tuple[NDArray[intp], ...]: ...
# keep in sync with `ma.core.where` and the 1-arg overloads with `_core.fromnumeric.nonzerp`
@overload # (?d) (workaround)
def where(condition: _ArrayJustND[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], ...]: ...
@overload # (1d)
def where(condition: _ToArray1D[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp]]: ...
@overload # (2d)
def where(condition: _ToArray2D[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], _Array1D[np.intp]]: ...
@overload # (3d)
def where(
condition: _ToArray3D[Any], x: None = None, y: None = None, /
) -> tuple[_Array1D[np.intp], _Array1D[np.intp], _Array1D[np.intp]]: ...
@overload # (Nd) (fallback)
def where(condition: _ArrayLike[Any], x: None = None, y: None = None, /) -> tuple[_Array1D[np.intp], ...]: ...
@overload
def where(condition: ArrayLike, x: ArrayLike, y: ArrayLike, /) -> NDArray[Incomplete]: ...

#
def lexsort(keys: ArrayLike, axis: SupportsIndex = -1) -> NDArray[intp]: ...

def can_cast(from_: ArrayLike | DTypeLike, to: DTypeLike, casting: _CastingKind = "safe") -> bool: ...
Expand Down
2 changes: 1 addition & 1 deletion numpy/typing/tests/data/fail/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ np.trace(A, axis2=[]) # type: ignore[call-overload]

np.ravel(a, order="bob") # type: ignore[call-overload]

np.nonzero(0) # type: ignore[arg-type]
np.nonzero(0) # type: ignore[call-overload]

np.compress([True], A, axis=1.0) # type: ignore[call-overload]

Expand Down
11 changes: 7 additions & 4 deletions numpy/typing/tests/data/reveal/fromnumeric.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ AR_b: npt.NDArray[np.bool]
AR_f4: npt.NDArray[np.float32]
AR_f4_1d: np.ndarray[tuple[int], np.dtype[np.float32]]
AR_f4_2d: np.ndarray[tuple[int, int], np.dtype[np.float32]]
AR_f4_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float32]]
AR_c16: npt.NDArray[np.complex128]
AR_u8: npt.NDArray[np.uint64]
AR_i8: npt.NDArray[np.int64]
Expand Down Expand Up @@ -155,10 +156,12 @@ assert_type(np.ravel(f), np.ndarray[tuple[int], np.dtype[np.float64 | Any]])
assert_type(np.ravel(AR_b), np.ndarray[tuple[int], np.dtype[np.bool]])
assert_type(np.ravel(AR_f4), np.ndarray[tuple[int], np.dtype[np.float32]])

assert_type(np.nonzero(AR_b), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...])
assert_type(np.nonzero(AR_f4), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...])
assert_type(np.nonzero(AR_1d), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...])
assert_type(np.nonzero(AR_nd), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...])
type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]]

assert_type(np.nonzero(AR_f4), tuple[_Int1D, ...])
assert_type(np.nonzero(AR_f4_1d), tuple[_Int1D])
assert_type(np.nonzero(AR_f4_2d), tuple[_Int1D, _Int1D])
assert_type(np.nonzero(AR_f4_3d), tuple[_Int1D, _Int1D, _Int1D])

assert_type(np.shape(b), tuple[()])
assert_type(np.shape(f), tuple[()])
Expand Down
3 changes: 1 addition & 2 deletions numpy/typing/tests/data/reveal/ma.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ assert_type(MAR_2d_f4.dot(1), MaskedArray[Any])
assert_type(MAR_2d_f4.dot([1]), MaskedArray[Any])
assert_type(MAR_2d_f4.dot(1, out=MAR_subclass), MaskedArraySubclassC)

assert_type(MAR_2d_f4.nonzero(), tuple[_Array1D[np.intp], ...])
assert_type(MAR_2d_f4.nonzero()[0], _Array1D[np.intp])
assert_type(MAR_2d_f4.nonzero(), tuple[_Array1D[np.intp], _Array1D[np.intp]])

assert_type(MAR_f8.trace(), Any)
assert_type(MAR_f8.trace(out=MAR_subclass), MaskedArraySubclassC)
Expand Down
48 changes: 47 additions & 1 deletion numpy/typing/tests/data/reveal/multiarray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ subclass: SubClass[np.float64]
AR_f4_nd: npt.NDArray[np.float32]
AR_f4_1d: np.ndarray[tuple[int], np.dtype[np.float32]]
AR_f4_2d: np.ndarray[tuple[int, int], np.dtype[np.float32]]
AR_f4_3d: np.ndarray[tuple[int, int, int], np.dtype[np.float32]]
AR_f8: npt.NDArray[np.float64]
AR_c16: npt.NDArray[np.complex128]
AR_i8: npt.NDArray[np.int64]
Expand Down Expand Up @@ -153,7 +154,52 @@ assert_type(np.dot(AR_O_nd, AR_O_nd), Any)

#

assert_type(np.where([True, True, False]), tuple[npt.NDArray[np.intp], ...])
assert_type(np.dot(AR_LIKE_b, AR_LIKE_b), np.bool)
assert_type(np.dot(AR_LIKE_b, AR_LIKE_i), np.int_)
assert_type(np.dot(AR_LIKE_b, AR_LIKE_f), np.float64)
assert_type(np.dot(AR_LIKE_b, AR_LIKE_c), np.complex128)
assert_type(np.dot(AR_LIKE_i, AR_LIKE_b), np.int_)
assert_type(np.dot(AR_LIKE_i, AR_LIKE_i), np.int_)
assert_type(np.dot(AR_LIKE_i, AR_LIKE_f), np.float64)
assert_type(np.dot(AR_LIKE_i, AR_LIKE_c), np.complex128)
assert_type(np.dot(AR_LIKE_f, AR_LIKE_b), np.float64)
assert_type(np.dot(AR_LIKE_f, AR_LIKE_i), np.float64)
assert_type(np.dot(AR_LIKE_f, AR_LIKE_f), np.float64)
assert_type(np.dot(AR_LIKE_f, AR_LIKE_c), np.complex128)
assert_type(np.dot(AR_LIKE_c, AR_LIKE_b), np.complex128)
assert_type(np.dot(AR_LIKE_c, AR_LIKE_i), np.complex128)
assert_type(np.dot(AR_LIKE_c, AR_LIKE_f), np.complex128)
assert_type(np.dot(AR_LIKE_c, AR_LIKE_c), np.complex128)

assert_type(np.dot(AR_f4_1d, AR_f4_1d), np.float32)
assert_type(np.dot(AR_f4_1d, AR_f4_2d), np.ndarray[tuple[int], np.dtype[np.float32]])
assert_type(np.dot(AR_f4_1d, AR_f4_nd), Any)
assert_type(np.dot(AR_f4_2d, AR_f4_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
assert_type(np.dot(AR_f4_2d, AR_f4_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
assert_type(np.dot(AR_f4_2d, AR_f4_nd), Any)
assert_type(np.dot(AR_f4_nd, AR_f4_1d), Any)
assert_type(np.dot(AR_f4_nd, AR_f4_2d), Any)
assert_type(np.dot(AR_f4_nd, AR_f4_nd), Any)

assert_type(np.dot(AR_O_1d, AR_O_1d), Any)
assert_type(np.dot(AR_O_1d, AR_O_2d), np.ndarray[tuple[int], np.dtype[np.object_]])
assert_type(np.dot(AR_O_1d, AR_O_nd), Any)
assert_type(np.dot(AR_O_2d, AR_O_1d), np.ndarray[tuple[int], np.dtype[np.object_]])
assert_type(np.dot(AR_O_2d, AR_O_2d), np.ndarray[tuple[int, int], np.dtype[np.object_]])
assert_type(np.dot(AR_O_2d, AR_O_nd), Any)
assert_type(np.dot(AR_O_nd, AR_O_1d), Any)
assert_type(np.dot(AR_O_nd, AR_O_2d), Any)
assert_type(np.dot(AR_O_nd, AR_O_nd), Any)

#

type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]]

assert_type(np.where([True, True, False]), tuple[_Int1D,])
assert_type(np.where(AR_f4_1d), tuple[_Int1D])
assert_type(np.where(AR_f4_2d), tuple[_Int1D, _Int1D])
assert_type(np.where(AR_f4_3d), tuple[_Int1D, _Int1D, _Int1D])
assert_type(np.where(AR_f4_nd), tuple[_Int1D, ...])
assert_type(np.where([True, True, False], 1, 0), npt.NDArray[Any])

assert_type(np.lexsort([0, 1, 2]), npt.NDArray[np.intp])
Expand Down
7 changes: 6 additions & 1 deletion numpy/typing/tests/data/reveal/ndarray_misc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,12 @@ assert_type(AR_f8.dot(1), npt.NDArray[Any])
assert_type(AR_f8.dot([1]), Any)
assert_type(AR_f8.dot(1, out=B), SubClass)

assert_type(AR_f8.nonzero(), tuple[np.ndarray[tuple[int], np.dtype[np.intp]], ...])
type _Int1D = np.ndarray[tuple[int], np.dtype[np.intp]]

assert_type(AR_f8.nonzero(), tuple[_Int1D, ...])
assert_type(AR_f8_1d.nonzero(), tuple[_Int1D])
assert_type(AR_f8_2d.nonzero(), tuple[_Int1D, _Int1D])
assert_type(AR_f8_3d.nonzero(), tuple[_Int1D, _Int1D, _Int1D])

assert_type(AR_f8.searchsorted(1), np.intp)
assert_type(AR_f8.searchsorted([1]), npt.NDArray[np.intp])
Expand Down
Loading