# _indexing_functions.py
from __future__ import annotations

import arrayfire as af

from ._array_object import Array

def take(x: Array, indices: Array, /, *, axis: int | None = None) -> Array:
    """
    Returns elements of an array along a specified axis using a set of indices.

    This function extracts elements from the input array `x` at the positions given by the `indices` array. The
    operation is similar to integer-array ("fancy") indexing in NumPy, but `indices` must be a one-dimensional array.
    Selection is made along a single axis. If no axis is specified and `x` is one-dimensional, the function behaves
    as if `axis=0` were given.

    Parameters
    ----------
    x : Array
        The input array from which to take elements.
    indices : Array
        A one-dimensional array of integer indices specifying which elements to extract.
    axis : int | None, optional
        The axis over which to select values. A negative axis is counted from the last dimension (for example,
        `-1` refers to the last axis). For one-dimensional `x`, `axis` is optional; for multi-dimensional `x`,
        `axis` is required.

    Returns
    -------
    Array
        An array containing the selected elements. The output has the same rank as `x`; its size along `axis`
        equals the number of elements in `indices`, and all other dimensions match those of `x`.

    Notes
    -----
    - The function covers part of the behavior of NumPy advanced indexing, but is constrained by the current
      specification, which avoids `__setitem__` and mutating an array through indexing.
    - If `axis` is None and `x` is multi-dimensional, a `ValueError` is raised, because an axis must be specified.

    Raises
    ------
    ValueError
        If `axis` is None and `x` is multi-dimensional, or if `indices` is not a one-dimensional array.
    """
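    # Only one-dimensional index arrays are supported.
    if indices.ndim != 1:
        raise ValueError("`indices` must be a one-dimensional array")

    if axis is None:
        # `axis` may be omitted only for one-dimensional input.
        if x.ndim > 1:
            raise ValueError("`axis` must be specified when `x` is multi-dimensional")
        axis = 0

    # Validate the axis, then normalize a negative axis so it counts from the last dimension.
    if not -x.ndim <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of rank {x.ndim}")
    axis %= x.ndim

    # Select along `axis` on the original array; flattening `x` first would discard its
    # other dimensions (and left `afarray` undefined when axis == 0).
    return Array._new(af.lookup(x._array, indices._array, axis=axis))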
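The docstring pins down a small contract: one-dimensional indices, `axis` required for multi-dimensional input, negative axes counted from the end, and an output of the same rank with `len(indices)` elements along `axis`. A NumPy reference can serve as an oracle when testing that contract. The sketch below is an assumption and not part of this module: `take_reference` is a hypothetical helper name, and it delegates the actual selection to `numpy.take`, which performs the same single-axis selection on NumPy arrays.

```python
from __future__ import annotations

import numpy as np


def take_reference(x: np.ndarray, indices: np.ndarray, axis: int | None = None) -> np.ndarray:
    """Hypothetical NumPy oracle for the `take` contract described in the docstring."""
    # Only one-dimensional index arrays are accepted.
    if indices.ndim != 1:
        raise ValueError("indices must be a one-dimensional array")
    if axis is None:
        # axis may be omitted only when x is one-dimensional.
        if x.ndim > 1:
            raise ValueError("axis must be specified when x is multi-dimensional")
        axis = 0
    # Reject out-of-range axes, then map negative axes onto 0..ndim-1.
    if not -x.ndim <= axis < x.ndim:
        raise ValueError(f"axis {axis} is out of bounds for an array of rank {x.ndim}")
    axis %= x.ndim
    return np.take(x, indices, axis=axis)


x = np.arange(12).reshape(3, 4)   # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
idx = np.array([2, 0])
print(take_reference(x, idx, axis=0).tolist())   # [[8, 9, 10, 11], [0, 1, 2, 3]]
print(take_reference(x, idx, axis=-1).tolist())  # [[2, 0], [6, 4], [10, 8]]
```

Comparing the ArrayFire-backed result (converted to NumPy) against this oracle on a few shapes and axes is one way to catch regressions such as selecting along a flattened array.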