forked from SciSharp/NumSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNDArray.Indexing.cs
More file actions
114 lines (105 loc) · 3.35 KB
/
NDArray.Indexing.cs
File metadata and controls
114 lines (105 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace NumSharp.Core
{
    public partial class NDArray<T>
    {
        /// <summary>
        /// Reads or writes a single element addressed by one index per leading axis.
        /// </summary>
        /// <param name="select">
        /// Zero-based indices, one per axis starting at axis 0. When fewer indices than
        /// <c>NDim</c> are given, the missing trailing axes are treated as index 0.
        /// </param>
        /// <returns>The element at the computed position in <c>Data</c>.</returns>
        /// <exception cref="IndexOutOfRangeException">
        /// More indices than dimensions were given, or an index lies outside its axis.
        /// </exception>
        public T this[params int[] select]
        {
            get
            {
                return Data[GetIndexInShape(select)];
            }
            set
            {
                Data[GetIndexInShape(select)] = value;
            }
        }

        /// <summary>
        /// Returns a copy of the sub-array obtained by fixing the leading axes.
        /// </summary>
        /// <param name="select">
        /// A <see cref="Shape"/> whose entries are used as indices for the leading axes,
        /// not as dimension sizes. It must have at least one entry and fewer entries than <c>NDim</c>.
        /// </param>
        /// <returns>
        /// A new array holding a copy of the selected elements. Its shape is the remaining
        /// trailing axes of this array.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="select"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="select"/> is empty, or it indexes every axis (use the element indexer instead).
        /// </exception>
        /// <exception cref="IndexOutOfRangeException">
        /// Too many indices were given, or an index lies outside its axis.
        /// </exception>
        public NDArray<T> this[Shape select]
        {
            get
            {
                if (select is null)
                {
                    throw new ArgumentNullException(nameof(select));
                }
                if (select.Length == 0)
                {
                    throw new ArgumentException("At least one index is required to select a sub-array.");
                }
                if (select.Length == NDim)
                {
                    throw new ArgumentException("Please use NDArray[m, n] to access element.");
                }

                // Validates every index against its axis and rejects select.Length > NDim.
                int start = GetIndexInShape(select.Shapes.ToArray());

                // The selected sub-array is copied as one contiguous run whose length is
                // the stride of the last fixed axis.
                // NOTE(review): assumes DimOffset holds row-major (C-order) strides — confirm in Shape.
                int length = Shape.DimOffset[select.Length - 1];

                int[] trailing = new int[Shape.Length - select.Length];
                for (int axis = select.Length; axis < Shape.Length; axis++)
                {
                    trailing[axis - select.Length] = Shape[axis];
                }

                Span<T> source = Data;
                var result = new NDArray<T>();
                result.Data = source.Slice(start, length).ToArray();
                result.Shape = new Shape(trailing);
                return result;
            }
        }

        /// <summary>
        /// Filter specific elements through select: gathers the slices at the given
        /// positions along axis 0. Works for arrays of any rank of at least 1.
        /// </summary>
        /// <param name="select">Zero-based positions along axis 0. Any order and duplicates are allowed.</param>
        /// <returns>
        /// Return a new NDArray with filtered elements, shaped
        /// <c>(select.Count, Shape[1], ..., Shape[NDim - 1])</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="select"/> is null.</exception>
        /// <exception cref="NotImplementedException">This array has no dimensions.</exception>
        /// <exception cref="IndexOutOfRangeException">A position lies outside axis 0.</exception>
        public NDArray<T> this[IList<int> select]
        {
            get
            {
                if (select is null)
                {
                    throw new ArgumentNullException(nameof(select));
                }
                if (NDim < 1)
                {
                    throw new NotImplementedException();
                }

                // The result keeps every axis except 0, whose length becomes the number of selections.
                // rowSize is the element count of one axis-0 slice (1 for a vector).
                int rowSize = 1;
                int[] dims = new int[NDim];
                dims[0] = select.Count;
                for (int axis = 1; axis < NDim; axis++)
                {
                    dims[axis] = Shape[axis];
                    rowSize *= Shape[axis];
                }

                var result = new NDArray<T>();
                result.Data = new T[select.Count * rowSize];
                result.Shape = new Shape(dims);

                Span<T> source = Data;
                Span<T> target = result.Data;
                for (int i = 0; i < select.Count; i++)
                {
                    // GetIndexInShape bounds-checks the row and yields its first element's offset.
                    // NOTE(review): copying rowSize elements as one run assumes a contiguous
                    // row-major layout, the same assumption the Shape indexer makes.
                    int start = GetIndexInShape(select[i]);
                    source.Slice(start, rowSize).CopyTo(target.Slice(i * rowSize, rowSize));
                }
                return result;
            }
        }

        /// <summary>
        /// Gathers the axis-0 slices at the positions stored in <paramref name="select"/>.
        /// </summary>
        /// <param name="select">An integer array whose elements are positions along axis 0.</param>
        /// <returns>A new array with the selected slices. See the <c>IList&lt;int&gt;</c> indexer.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="select"/> is null.</exception>
        public NDArray<T> this[NDArray<int> select]
        {
            get
            {
                if (select is null)
                {
                    throw new ArgumentNullException(nameof(select));
                }

                // Convert to a List: passing the int[] itself would bind to the
                // params int[] element indexer, which returns a single T.
                return this[select.Data.ToList()];
            }
        }

        /// <summary>
        /// Converts per-axis indices into a position in <c>Data</c>, using
        /// <c>Shape.DimOffset</c> as the per-axis strides.
        /// </summary>
        /// <param name="select">Zero-based indices for the leading axes, at most <c>NDim</c> of them.</param>
        /// <returns>The offset of the addressed element in <c>Data</c>.</returns>
        /// <exception cref="IndexOutOfRangeException">
        /// Too many indices were given, or an index lies outside its axis. This is the same
        /// exception type that the raw <c>Data[]</c> access raised for bad positions.
        /// </exception>
        private int GetIndexInShape(params int[] select)
        {
            if (select.Length > NDim)
            {
                throw new IndexOutOfRangeException(
                    $"Too many indices: the array has {NDim} dimension(s) but {select.Length} indices were given.");
            }

            int offset = 0;
            for (int axis = 0; axis < select.Length; axis++)
            {
                int index = select[axis];
                int size = Shape[axis];

                // Without this check an over-large or negative index silently aliases
                // another element instead of failing.
                if (index < 0 || index >= size)
                {
                    throw new IndexOutOfRangeException(
                        $"Index {index} is out of bounds for axis {axis} with size {size}.");
                }

                offset += Shape.DimOffset[axis] * index;
            }
            return offset;
        }
    }
}