forked from SciSharp/NumSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNdArray.Convolve.cs
More file actions
161 lines (137 loc) · 7.46 KB
/
NdArray.Convolve.cs
File metadata and controls
161 lines (137 loc) · 7.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
namespace NumSharp
{
    public partial class NDArray
    {
        /// <summary>
        ///     Returns the discrete, linear convolution of two one-dimensional sequences: this array and
        ///     <paramref name="rhs"/>. Follows <c>numpy.convolve</c>.
        ///
        ///     The convolution operator is often seen in signal processing, where it models the effect of a linear
        ///     time-invariant system on a signal. In probability theory, the sum of two independent random variables
        ///     is distributed according to the convolution of their individual distributions.
        ///
        ///     Convolution is commutative. For <c>"valid"</c>, the shorter operand is used as the kernel regardless of
        ///     which side it was passed on.
        /// </summary>
        /// <param name="rhs">
        ///     The second one-dimensional operand. A 0-d array is treated as a sequence of length 1, as numpy does.
        /// </param>
        /// <param name="mode">
        ///     Case-insensitive output-size selector (let N and M be the operand lengths):
        ///     <c>"full"</c> (default) returns all <c>N + M - 1</c> overlap positions;
        ///     <c>"valid"</c> returns only the <c>max(M, N) - min(M, N) + 1</c> positions where the operands fully overlap;
        ///     <c>"same"</c> returns <c>max(M, N)</c> elements.
        /// </param>
        /// <returns>
        ///     The convolution, using the common element type of both operands.
        ///     NOTE: the per-dtype kernels come from the <c>_REGEN</c> template in this method and have not been
        ///     generated into this file. After validating its arguments, this method currently returns <c>null</c>.
        /// </returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="rhs"/> or <paramref name="mode"/> is null.</exception>
        /// <exception cref="IncorrectShapeException">Either operand has more than one dimension.</exception>
        /// <exception cref="System.ArgumentException">Either operand is empty.</exception>
        /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="mode"/> is not "full", "valid" or "same".</exception>
        public NDArray convolve(NDArray rhs, string mode = "full")
        {
            if (rhs is null)
            {
                throw new System.ArgumentNullException(nameof(rhs));
            }

            if (mode is null)
            {
                throw new System.ArgumentNullException(nameof(mode));
            }

            var lhs = this;

            // Check the rank before reading shape[0]: a 0-d array has no first dimension to index.
            if (lhs.ndim > 1 || rhs.ndim > 1)
            {
                throw new IncorrectShapeException();
            }

            // numpy promotes 0-d inputs to length-1 vectors (array(a, ndmin=1)).
            int nf = lhs.ndim == 0 ? 1 : lhs.shape[0];
            int ng = rhs.ndim == 0 ? 1 : rhs.shape[0];

            // numpy rejects empty operands. Without this check, "full" would request a vector of length nf + ng - 1 = -1.
            if (nf == 0)
            {
                throw new System.ArgumentException("The array to convolve cannot be empty.");
            }

            if (ng == 0)
            {
                throw new System.ArgumentException("The array to convolve with cannot be empty.", nameof(rhs));
            }

            // These are the same three names the generator template switches on (after ToLowerInvariant).
            if (!string.Equals(mode, "full", System.StringComparison.OrdinalIgnoreCase)
                && !string.Equals(mode, "valid", System.StringComparison.OrdinalIgnoreCase)
                && !string.Equals(mode, "same", System.StringComparison.OrdinalIgnoreCase))
            {
                throw new System.ArgumentOutOfRangeException(nameof(mode), mode, "mode must be \"full\", \"valid\" or \"same\".");
            }

            // Consumed by the generator template below and by the code it produces.
            var retType = np._FindCommonType(lhs, rhs);

            // TODO: replace with the code generated from the _REGEN template below.
            return null;

            // NOTE(review): the template below is compiled only when _REGEN is defined, so none of it runs in this
            // build. Resolve these issues before generating it:
            //  - "valid": the ternaries selecting min_v/max_v mix ArraySlice<#2> and ArraySlice<#102>. Those types
            //    do not unify when lhs and rhs have different element types.
            //  - "same": lhs is padded on the left only, and twice in the even-npad branch. The odd branch wraps an
            //    (npad + nf)-element buffer in a shape of length nf. numpy's "same" returns max(M, N) elements.
#if _REGEN
            #region Output
            %mod = "%"
            switch (lhs.GetTypeCode)
            {
                %foreach supported_numericals,supported_numericals_lowercase%
                case NPTypeCode.#1:
                {
                    ArraySlice<#2> lhsarr = lhs.Storage.GetData<#2>();
                    switch (rhs.GetTypeCode)
                    {
                        %foreach supported_numericals,supported_numericals_lowercase%
                        case NPTypeCode.#101:
                        {
                            ArraySlice<#102> rhsarr = rhs.Storage.GetData<#102>();
                            %foreach supported_numericals,supported_numericals_lowercase%
                            switch (retType)
                            {
                                case NPTypeCode.#201:
                                {
                                    #region Compute
                                    switch (mode.ToLowerInvariant())
                                    {
                                        case "full":
                                        {
                                            int n = nf + ng - 1;
                                            var ret = new NDArray<#201>(Shape.Vector(n), true);
                                            var outArray = (ArraySlice<#202>)ret.Array;
                                            for (int idx = 0; idx < n; ++idx)
                                            {
                                                int jmn = (idx >= ng - 1) ? (idx - (ng - 1)) : 0;
                                                int jmx = (idx < nf - 1) ? idx : nf - 1;
                                                for (int jdx = jmn; jdx <= jmx; ++jdx)
                                                {
                                                    outArray[idx] += Convert.To#201(lhsarr[jdx] * rhsarr[idx - jdx]);
                                                }
                                            }
                                            return ret;
                                        }
                                        case "valid":
                                        {
                                            var min_v = (nf < ng) ? lhsarr : rhsarr;
                                            var max_v = (nf < ng) ? rhsarr : lhsarr;
                                            int n = Math.Max(nf, ng) - Math.Min(nf, ng) + 1;
                                            var ret = new NDArray(retType, Shape.Vector(n), true);
                                            var outArray = (ArraySlice<#202>)ret.Array;
                                            for (int idx = 0; idx < n; ++idx)
                                            {
                                                int kdx = idx;
                                                for (int jdx = (min_v.Count - 1); jdx >= 0; --jdx)
                                                {
                                                    outArray[idx] += Convert.To#201(min_v[jdx] * max_v[kdx]);
                                                    ++kdx;
                                                }
                                            }
                                            return ret;
                                        }
                                        case "same":
                                        {
                                            // followed the discussion on
                                            // https://stackoverflow.com/questions/38194270/matlab-convolution-same-to-numpy-convolve
                                            // implemented numpy convolve because we follow numpy
                                            var npad = rhs.shape[0] - 1;
                                            if (npad #(mod) 2 == 1)
                                            {
                                                unsafe
                                                {
                                                    npad = (int)Math.Floor(((double)npad) / 2.0);
                                                    var arr = ArraySlice<#202>.Allocate(npad + lhsarr.Count);
                                                    lhsarr.CopyTo(arr.AsSpan, npad);
                                                    var retnd = new NDArray(new UnmanagedStorage(arr, Shape.Vector(lhsarr.Count)));
                                                    return retnd.convolve(rhs, "valid");
                                                }
                                            }
                                            else
                                            {
                                                {
                                                    unsafe
                                                    {
                                                        npad = npad / 2;
                                                        var puffer = new NDArray(retType, Shape.Vector(npad + lhsarr.Count), true);
                                                        lhsarr.CopyTo(puffer.Storage.AsSpan<#202>(), npad);
                                                        var np1New = puffer;
                                                        puffer = new NDArray(retType, Shape.Vector(npad + np1New.size), true);
                                                        var cpylen = np1New.size * sizeof(#202);
                                                        Buffer.MemoryCopy(np1New.Address, ((#202*)puffer.Address) + npad, cpylen, cpylen);
                                                        return puffer.convolve(rhs, "valid");
                                                    }
                                                }
                                            }
                                        }
                                        default:
                                            throw new ArgumentOutOfRangeException(nameof(mode));
                                    }
                                    #endregion
                                }
                            }
                            %
                            break;
                        }
                        %
                    }
                    break;
                }
                %
                default:
                    throw new NotSupportedException();
            }
            #endregion
#else
#endif
        }
    }
}