Skip to content

Commit 1d6144c

Browse files
authored
TYP: all and any shape-typing (#31170)
1 parent 6cb0da6 commit 1d6144c

4 files changed

Lines changed: 151 additions & 65 deletions

File tree

numpy/__init__.pyi

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,77 +2355,97 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeT_co, _DTypeT_co]):
23552355
) -> ArrayT: ...
23562356

23572357
#
2358+
# keep in sync with `ndarray.any` (below)
23582359
@overload
23592360
def all(
23602361
self,
23612362
axis: None = None,
23622363
out: None = None,
2363-
keepdims: L[False, 0] = False,
2364+
keepdims: L[False] = False,
23642365
*,
23652366
where: _ArrayLikeBool_co = True
23662367
) -> bool_: ...
2367-
@overload
2368+
@overload # axis: <given>
2369+
def all(
2370+
self,
2371+
axis: int | tuple[int, ...],
2372+
out: None = None,
2373+
keepdims: L[False] = False,
2374+
*,
2375+
where: _ArrayLikeBool_co = True,
2376+
) -> NDArray[bool_]: ...
2377+
@overload # keepdims: True
23682378
def all(
23692379
self,
23702380
axis: int | tuple[int, ...] | None = None,
23712381
out: None = None,
2372-
keepdims: SupportsIndex = False,
23732382
*,
2383+
keepdims: L[True],
23742384
where: _ArrayLikeBool_co = True,
2375-
) -> bool_ | NDArray[bool_]: ...
2376-
@overload
2385+
) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2386+
@overload # out: <given> (keyword)
23772387
def all[ArrayT: ndarray](
23782388
self,
2379-
axis: int | tuple[int, ...] | None,
2380-
out: ArrayT,
2381-
keepdims: SupportsIndex = False,
2389+
axis: int | tuple[int, ...] | None = None,
23822390
*,
2391+
out: ArrayT,
2392+
keepdims: py_bool = False,
23832393
where: _ArrayLikeBool_co = True,
23842394
) -> ArrayT: ...
2385-
@overload
2395+
@overload # out: <given> (positional)
23862396
def all[ArrayT: ndarray](
23872397
self,
2388-
axis: int | tuple[int, ...] | None = None,
2389-
*,
2398+
axis: int | tuple[int, ...] | None,
23902399
out: ArrayT,
2391-
keepdims: SupportsIndex = False,
2400+
keepdims: py_bool = False,
2401+
*,
23922402
where: _ArrayLikeBool_co = True,
23932403
) -> ArrayT: ...
23942404

2405+
# keep in sync with `ndarray.all` (above)
23952406
@overload
23962407
def any(
23972408
self,
23982409
axis: None = None,
23992410
out: None = None,
2400-
keepdims: L[False, 0] = False,
2411+
keepdims: L[False] = False,
24012412
*,
24022413
where: _ArrayLikeBool_co = True
24032414
) -> bool_: ...
2404-
@overload
2415+
@overload # axis: <given>
2416+
def any(
2417+
self,
2418+
axis: int | tuple[int, ...],
2419+
out: None = None,
2420+
keepdims: L[False] = False,
2421+
*,
2422+
where: _ArrayLikeBool_co = True,
2423+
) -> NDArray[bool_]: ...
2424+
@overload # keepdims: True
24052425
def any(
24062426
self,
24072427
axis: int | tuple[int, ...] | None = None,
24082428
out: None = None,
2409-
keepdims: SupportsIndex = False,
24102429
*,
2430+
keepdims: L[True],
24112431
where: _ArrayLikeBool_co = True,
2412-
) -> bool_ | NDArray[bool_]: ...
2413-
@overload
2432+
) -> ndarray[_ShapeT_co, dtype[bool_]]: ...
2433+
@overload # out: <given> (keyword)
24142434
def any[ArrayT: ndarray](
24152435
self,
2416-
axis: int | tuple[int, ...] | None,
2417-
out: ArrayT,
2418-
keepdims: SupportsIndex = False,
2436+
axis: int | tuple[int, ...] | None = None,
24192437
*,
2438+
out: ArrayT,
2439+
keepdims: py_bool = False,
24202440
where: _ArrayLikeBool_co = True,
24212441
) -> ArrayT: ...
2422-
@overload
2442+
@overload # out: <given> (positional)
24232443
def any[ArrayT: ndarray](
24242444
self,
2425-
axis: int | tuple[int, ...] | None = None,
2426-
*,
2445+
axis: int | tuple[int, ...] | None,
24272446
out: ArrayT,
2428-
keepdims: SupportsIndex = False,
2447+
keepdims: py_bool = False,
2448+
*,
24292449
where: _ArrayLikeBool_co = True,
24302450
) -> ArrayT: ...
24312451

numpy/_core/fromnumeric.pyi

Lines changed: 83 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from _typeshed import Incomplete, SupportsBool
1+
from _typeshed import SupportsBool
22
from collections.abc import Sequence
33
from typing import (
44
Any,
@@ -882,35 +882,62 @@ def all(
882882
a: ArrayLike | None,
883883
axis: None = None,
884884
out: None = None,
885-
keepdims: Literal[False, 0] | _NoValueType = ...,
885+
keepdims: Literal[False] | _NoValueType = ...,
886886
*,
887887
where: _ArrayLikeBool_co | _NoValueType = ...,
888888
) -> np.bool: ...
889-
@overload
890-
def all(
891-
a: ArrayLike | None,
889+
@overload # axis: int
890+
def all[ShapeT: _Shape](
891+
a: ArrayLike,
892+
axis: int,
893+
out: None = None,
894+
keepdims: Literal[False] | _NoValueType = ...,
895+
*,
896+
where: _ArrayLikeBool_co | _NoValueType = ...,
897+
) -> NDArray[np.bool]: ...
898+
@overload # axis: (int, ...)
899+
def all[ShapeT: _Shape](
900+
a: ArrayLike,
901+
axis: tuple[int, ...],
902+
out: None = None,
903+
keepdims: Literal[False] | _NoValueType = ...,
904+
*,
905+
where: _ArrayLikeBool_co | _NoValueType = ...,
906+
) -> NDArray[np.bool] | Any: ...
907+
@overload # Nd, keepdims: True
908+
def all[ShapeT: _Shape](
909+
a: np.ndarray[ShapeT],
892910
axis: int | tuple[int, ...] | None = None,
893911
out: None = None,
894-
keepdims: _BoolLike_co | _NoValueType = ...,
895912
*,
913+
keepdims: Literal[True],
896914
where: _ArrayLikeBool_co | _NoValueType = ...,
897-
) -> Incomplete: ...
898-
@overload
915+
) -> np.ndarray[ShapeT, np.dtype[np.bool]]: ...
916+
@overload # ?d, keepdims: True
917+
def all[ShapeT: _Shape](
918+
a: ArrayLike,
919+
axis: int | tuple[int, ...] | None = None,
920+
out: None = None,
921+
*,
922+
keepdims: Literal[True],
923+
where: _ArrayLikeBool_co | _NoValueType = ...,
924+
) -> NDArray[np.bool]: ...
925+
@overload # out: <given> (keyword)
899926
def all[ArrayT: np.ndarray](
900927
a: ArrayLike | None,
901-
axis: int | tuple[int, ...] | None,
902-
out: ArrayT,
903-
keepdims: _BoolLike_co | _NoValueType = ...,
928+
axis: int | tuple[int, ...] | None = None,
904929
*,
930+
out: ArrayT,
931+
keepdims: bool | _NoValueType = ...,
905932
where: _ArrayLikeBool_co | _NoValueType = ...,
906933
) -> ArrayT: ...
907-
@overload
934+
@overload # out: <given> (positional)
908935
def all[ArrayT: np.ndarray](
909936
a: ArrayLike | None,
910-
axis: int | tuple[int, ...] | None = None,
911-
*,
937+
axis: int | tuple[int, ...] | None,
912938
out: ArrayT,
913-
keepdims: _BoolLike_co | _NoValueType = ...,
939+
keepdims: bool | _NoValueType = ...,
940+
*,
914941
where: _ArrayLikeBool_co | _NoValueType = ...,
915942
) -> ArrayT: ...
916943

@@ -920,35 +947,62 @@ def any(
920947
a: ArrayLike | None,
921948
axis: None = None,
922949
out: None = None,
923-
keepdims: Literal[False, 0] | _NoValueType = ...,
950+
keepdims: Literal[False] | _NoValueType = ...,
924951
*,
925952
where: _ArrayLikeBool_co | _NoValueType = ...,
926953
) -> np.bool: ...
927-
@overload
928-
def any(
929-
a: ArrayLike | None,
954+
@overload # axis: int
955+
def any[ShapeT: _Shape](
956+
a: ArrayLike,
957+
axis: int,
958+
out: None = None,
959+
keepdims: Literal[False] | _NoValueType = ...,
960+
*,
961+
where: _ArrayLikeBool_co | _NoValueType = ...,
962+
) -> NDArray[np.bool]: ...
963+
@overload # axis: (int, ...)
964+
def any[ShapeT: _Shape](
965+
a: ArrayLike,
966+
axis: tuple[int, ...],
967+
out: None = None,
968+
keepdims: Literal[False] | _NoValueType = ...,
969+
*,
970+
where: _ArrayLikeBool_co | _NoValueType = ...,
971+
) -> NDArray[np.bool] | Any: ...
972+
@overload # Nd, keepdims: True
973+
def any[ShapeT: _Shape](
974+
a: np.ndarray[ShapeT],
930975
axis: int | tuple[int, ...] | None = None,
931976
out: None = None,
932-
keepdims: _BoolLike_co | _NoValueType = ...,
933977
*,
978+
keepdims: Literal[True],
934979
where: _ArrayLikeBool_co | _NoValueType = ...,
935-
) -> Incomplete: ...
936-
@overload
980+
) -> np.ndarray[ShapeT, np.dtype[np.bool]]: ...
981+
@overload # ?d, keepdims: True
982+
def any[ShapeT: _Shape](
983+
a: ArrayLike,
984+
axis: int | tuple[int, ...] | None = None,
985+
out: None = None,
986+
*,
987+
keepdims: Literal[True],
988+
where: _ArrayLikeBool_co | _NoValueType = ...,
989+
) -> NDArray[np.bool]: ...
990+
@overload # out: <given> (keyword)
937991
def any[ArrayT: np.ndarray](
938992
a: ArrayLike | None,
939-
axis: int | tuple[int, ...] | None,
940-
out: ArrayT,
941-
keepdims: _BoolLike_co | _NoValueType = ...,
993+
axis: int | tuple[int, ...] | None = None,
942994
*,
995+
out: ArrayT,
996+
keepdims: bool | _NoValueType = ...,
943997
where: _ArrayLikeBool_co | _NoValueType = ...,
944998
) -> ArrayT: ...
945-
@overload
999+
@overload # out: <given> (positional)
9461000
def any[ArrayT: np.ndarray](
9471001
a: ArrayLike | None,
948-
axis: int | tuple[int, ...] | None = None,
949-
*,
1002+
axis: int | tuple[int, ...] | None,
9501003
out: ArrayT,
951-
keepdims: _BoolLike_co | _NoValueType = ...,
1004+
keepdims: bool | _NoValueType = ...,
1005+
*,
9521006
where: _ArrayLikeBool_co | _NoValueType = ...,
9531007
) -> ArrayT: ...
9541008

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -212,21 +212,27 @@ assert_type(np.all(f4), np.bool)
212212
assert_type(np.all(f), np.bool)
213213
assert_type(np.all(AR_b), np.bool)
214214
assert_type(np.all(AR_f4), np.bool)
215-
assert_type(np.all(AR_b, axis=0), Any)
216-
assert_type(np.all(AR_f4, axis=0), Any)
217-
assert_type(np.all(AR_b, keepdims=True), Any)
218-
assert_type(np.all(AR_f4, keepdims=True), Any)
215+
assert_type(np.all(AR_b, axis=0), npt.NDArray[np.bool])
216+
assert_type(np.all(AR_f4, axis=0), npt.NDArray[np.bool])
217+
assert_type(np.all(AR_b, keepdims=True), npt.NDArray[np.bool])
218+
assert_type(np.all(AR_f4, keepdims=True), npt.NDArray[np.bool])
219+
assert_type(np.all(AR_f4_1d, keepdims=True), np.ndarray[tuple[int], np.dtype[np.bool]])
220+
assert_type(np.all(AR_f4_2d, keepdims=True), np.ndarray[tuple[int, int], np.dtype[np.bool]])
221+
assert_type(np.all(AR_f4_3d, keepdims=True), np.ndarray[tuple[int, int, int], np.dtype[np.bool]])
219222
assert_type(np.all(AR_f4, out=AR_subclass), NDArraySubclass)
220223

221224
assert_type(np.any(b), np.bool)
222225
assert_type(np.any(f4), np.bool)
223226
assert_type(np.any(f), np.bool)
224227
assert_type(np.any(AR_b), np.bool)
225228
assert_type(np.any(AR_f4), np.bool)
226-
assert_type(np.any(AR_b, axis=0), Any)
227-
assert_type(np.any(AR_f4, axis=0), Any)
228-
assert_type(np.any(AR_b, keepdims=True), Any)
229-
assert_type(np.any(AR_f4, keepdims=True), Any)
229+
assert_type(np.any(AR_b, axis=0), npt.NDArray[np.bool])
230+
assert_type(np.any(AR_f4, axis=0), npt.NDArray[np.bool])
231+
assert_type(np.any(AR_b, keepdims=True), npt.NDArray[np.bool])
232+
assert_type(np.any(AR_f4, keepdims=True), npt.NDArray[np.bool])
233+
assert_type(np.any(AR_f4_1d, keepdims=True), np.ndarray[tuple[int], np.dtype[np.bool]])
234+
assert_type(np.any(AR_f4_2d, keepdims=True), np.ndarray[tuple[int, int], np.dtype[np.bool]])
235+
assert_type(np.any(AR_f4_3d, keepdims=True), np.ndarray[tuple[int, int, int], np.dtype[np.bool]])
230236
assert_type(np.any(AR_f4, out=AR_subclass), NDArraySubclass)
231237

232238
assert_type(np.cumsum(b), np.ndarray[tuple[int], np.dtype[np.bool]])

numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,20 @@ assert_type(ctypes_obj.strides_as(ct.c_ubyte), ct.Array[ct.c_ubyte])
5858

5959
assert_type(f8.all(), np.bool)
6060
assert_type(AR_f8.all(), np.bool)
61-
assert_type(AR_f8.all(axis=0), np.bool | npt.NDArray[np.bool])
62-
assert_type(AR_f8.all(keepdims=True), np.bool | npt.NDArray[np.bool])
61+
assert_type(AR_f8.all(axis=0), npt.NDArray[np.bool])
62+
assert_type(AR_f8.all(keepdims=True), npt.NDArray[np.bool])
63+
assert_type(AR_f8_1d.all(keepdims=True), np.ndarray[tuple[int], np.dtype[np.bool]])
64+
assert_type(AR_f8_2d.all(keepdims=True), np.ndarray[tuple[int, int], np.dtype[np.bool]])
65+
assert_type(AR_f8_3d.all(keepdims=True), np.ndarray[tuple[int, int, int], np.dtype[np.bool]])
6366
assert_type(AR_f8.all(out=B), SubClass)
6467

6568
assert_type(f8.any(), np.bool)
6669
assert_type(AR_f8.any(), np.bool)
67-
assert_type(AR_f8.any(axis=0), np.bool | npt.NDArray[np.bool])
68-
assert_type(AR_f8.any(keepdims=True), np.bool | npt.NDArray[np.bool])
70+
assert_type(AR_f8.any(axis=0), npt.NDArray[np.bool])
71+
assert_type(AR_f8.any(keepdims=True), npt.NDArray[np.bool])
72+
assert_type(AR_f8_1d.any(keepdims=True), np.ndarray[tuple[int], np.dtype[np.bool]])
73+
assert_type(AR_f8_2d.any(keepdims=True), np.ndarray[tuple[int, int], np.dtype[np.bool]])
74+
assert_type(AR_f8_3d.any(keepdims=True), np.ndarray[tuple[int, int, int], np.dtype[np.bool]])
6975
assert_type(AR_f8.any(out=B), SubClass)
7076

7177
# same as below

0 commit comments

Comments
 (0)