Skip to content

Commit f2cffb4

Browse files
authored
TYP: partition shape-typing (#31169)
1 parent 282d065 commit f2cffb4

2 files changed

Lines changed: 31 additions & 13 deletions

File tree

numpy/_core/fromnumeric.pyi

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -345,30 +345,46 @@ def matrix_transpose[ScalarT: np.generic](x: _ArrayLike[ScalarT], /) -> NDArray[
345345
def matrix_transpose(x: ArrayLike, /) -> NDArray[Any]: ...
346346

347347
#
348-
@overload
348+
@overload # Nd
349+
def partition[ArrayT: np.ndarray](
350+
a: ArrayT,
351+
kth: _ArrayLikeInt,
352+
axis: SupportsIndex = -1,
353+
kind: _PartitionKind = "introselect",
354+
order: str | Sequence[str] | None = None,
355+
) -> ArrayT: ...
356+
@overload # ?d
349357
def partition[ScalarT: np.generic](
350358
a: _ArrayLike[ScalarT],
351359
kth: _ArrayLikeInt,
352-
axis: SupportsIndex | None = -1,
360+
axis: SupportsIndex = -1,
353361
kind: _PartitionKind = "introselect",
354-
order: None = None,
362+
order: str | Sequence[str] | None = None,
355363
) -> NDArray[ScalarT]: ...
356-
@overload
357-
def partition(
358-
a: _ArrayLike[np.void],
364+
@overload # axis: None
365+
def partition[ScalarT: np.generic](
366+
a: _ArrayLike[ScalarT],
359367
kth: _ArrayLikeInt,
360-
axis: SupportsIndex | None = -1,
368+
axis: None,
361369
kind: _PartitionKind = "introselect",
362370
order: str | Sequence[str] | None = None,
363-
) -> NDArray[np.void]: ...
364-
@overload
371+
) -> _Array1D[ScalarT]: ...
372+
@overload # fallback
365373
def partition(
366374
a: ArrayLike,
367375
kth: _ArrayLikeInt,
368-
axis: SupportsIndex | None = -1,
376+
axis: SupportsIndex = -1,
369377
kind: _PartitionKind = "introselect",
370378
order: str | Sequence[str] | None = None,
371379
) -> NDArray[Any]: ...
380+
@overload # fallback, axis: None
381+
def partition(
382+
a: ArrayLike,
383+
kth: _ArrayLikeInt,
384+
axis: None,
385+
kind: _PartitionKind = "introselect",
386+
order: str | Sequence[str] | None = None,
387+
) -> _Array1D[Any]: ...
372388

373389
# keep roughly in sync with `ndarray.argpartition`
374390
@overload # axis: None

numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ assert_type(np.transpose(AR_f4), npt.NDArray[np.float32])
8181
assert_type(np.transpose(AR_f4_1d), np.ndarray[tuple[int], np.dtype[np.float32]])
8282
assert_type(np.transpose(AR_f4_2d), np.ndarray[tuple[int, int], np.dtype[np.float32]])
8383

84-
assert_type(np.partition(b, 0, axis=None), npt.NDArray[np.bool])
85-
assert_type(np.partition(f4, 0, axis=None), npt.NDArray[np.float32])
86-
assert_type(np.partition(f, 0, axis=None), npt.NDArray[Any])
8784
assert_type(np.partition(AR_b, 0), npt.NDArray[np.bool])
85+
assert_type(np.partition(AR_b, 0, axis=None), np.ndarray[tuple[int], np.dtype[np.bool]])
8886
assert_type(np.partition(AR_f4, 0), npt.NDArray[np.float32])
87+
assert_type(np.partition(AR_f4, 0, axis=None), np.ndarray[tuple[int], np.dtype[np.float32]])
88+
assert_type(np.partition(AR_f4_1d, 0), np.ndarray[tuple[int], np.dtype[np.float32]])
89+
assert_type(np.partition(AR_f4_2d, 0), np.ndarray[tuple[int, int], np.dtype[np.float32]])
90+
assert_type(np.partition(AR_f4_3d, 0), np.ndarray[tuple[int, int, int], np.dtype[np.float32]])
8991

9092
assert_type(np.argpartition(b, 0), npt.NDArray[np.intp])
9193
assert_type(np.argpartition(f4, 0), npt.NDArray[np.intp])

0 commit comments

Comments
 (0)