Skip to content

Commit bb9b9b8

Browse files
committed
Added np.argmax, np.argmin
1 parent 897bbc7 commit bb9b9b8

14 files changed

Lines changed: 1611 additions & 233 deletions

File tree

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using NumSharp.Backends;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace NumSharp
7+
{
8+
public static partial class np
9+
{
10+
11+
/// <summary>
12+
/// Returns the indices of the maximum values along an axis.
13+
/// </summary>
14+
/// <param name="a">Input array.</param>
15+
/// <param name="axis">By default, the index is into the flattened array, otherwise along the specified axis.</param>
16+
/// <returns>Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.</returns>
17+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html</remarks>
18+
public static NDArray argmax(NDArray a, int axis)
19+
=> a.TensorEngine.ArgMax(a, axis: axis);
20+
21+
/// <summary>
22+
/// Returns the index of the maximum value.
23+
/// </summary>
24+
/// <param name="a">Input array.</param>
25+
/// <returns>Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.</returns>
26+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmax.html</remarks>
27+
public static int argmax(NDArray a)
28+
=> a.TensorEngine.ArgMax(a);
29+
30+
/// <summary>
31+
/// Returns the indices of the minimum values along an axis.
32+
/// </summary>
33+
/// <param name="a">Input array.</param>
34+
/// <param name="axis">By default, the index is into the flattened array, otherwise along the specified axis.</param>
35+
/// <returns>Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.</returns>
36+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmin.html</remarks>
37+
public static NDArray argmin(NDArray a, int axis)
38+
=> a.TensorEngine.ArgMin(a, axis: axis);
39+
40+
/// <summary>
41+
/// Returns the index of the minimum value.
42+
/// </summary>
43+
/// <param name="a">Input array.</param>
44+
/// <returns>Array of indices into the array. It has the same shape as a.shape with the dimension along axis removed.</returns>
45+
/// <remarks>https://docs.scipy.org/doc/numpy/reference/generated/numpy.argmin.html</remarks>
46+
public static int argmin(NDArray a)
47+
=> a.TensorEngine.ArgMin(a);
48+
}
49+
}

src/NumSharp.Core/APIs/np.sorting_searching_counting.cs renamed to src/NumSharp.Core/APIs/np.argsort.cs

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,6 @@ namespace NumSharp
77
{
88
public static partial class np
99
{
10-
/// <summary>
11-
/// Returns the index of the maximum value of the array.
12-
/// </summary>
13-
public static NDArray argmax(NDArray nd, int axis = -1)
14-
=> BackendFactory.GetEngine().ArgMax(nd, axis: axis);
15-
16-
/// <summary>
17-
/// Returns the index of the maximum value of the array.
18-
/// </summary>
19-
public static int argmax<T>(NDArray nd)
20-
=> nd.argmax<T>();
21-
2210
/// <summary>
2311
/// Returns the indices that would sort an array.
2412
///

0 commit comments

Comments
 (0)