Skip to content

Commit 71f2fc9

Browse files
committed
well structured
1 parent 9dd68b3 commit 71f2fc9

30 files changed

Lines changed: 603 additions & 652 deletions
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using NumSharp.Backends;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace NumSharp
7+
{
8+
public partial class NDArray
9+
{
10+
public NDArray log()
11+
=> BackendFactory.GetEngine().Log(this);
12+
}
13+
}

src/NumSharp.Core/APIs/np.linear.algebra.cs renamed to src/NumSharp.Core/APIs/np.blas.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ public static partial class np
1212
/// if both NDArrays are 1D, scalar product is returned independend of shape
1313
/// if both NDArrays are 2D matrix product is returned.
1414
/// </summary>
15-
/// <param name="a"></param>
16-
/// <param name="b"></param>
15+
/// <param name="x"></param>
16+
/// <param name="y"></param>
1717
/// <returns></returns>
18-
public static NDArray dot(NDArray a, NDArray b)
19-
=> BackendFactory.GetEngine().Dot(a, b);
18+
public static NDArray dot(NDArray x, NDArray y)
19+
=> BackendFactory.GetEngine().Dot(x, y);
20+
21+
public static NDArray matmul(NDArray x, NDArray y)
22+
=> BackendFactory.GetEngine().MatMul(x, y);
2023
}
2124
}

src/NumSharp.Core/APIs/np.math.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ namespace NumSharp
77
{
88
public static partial class np
99
{
10-
public static int add(NDArray x, NDArray y)
10+
public static NDArray add(NDArray x, NDArray y)
1111
=> BackendFactory.GetEngine().Add(x, y);
12+
13+
public static NDArray log(NDArray x)
14+
=> BackendFactory.GetEngine().Log(x);
1215
}
1316
}
File renamed without changes.
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
using ArrayFire;
2+
using NumSharp.Interfaces;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using Array = ArrayFire.Array;
7+
8+
namespace NumSharp.Backends
9+
{
10+
public class ArrayFireEngine : DefaultEngine
11+
{
12+
public override NDArray Add(NDArray x, NDArray y)
13+
{
14+
return base.Add(x, y);
15+
}
16+
}
17+
}

src/NumSharp.Core/Backends/ArrayFireEngine.cs

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/NumSharp.Core/Backends/BackendFactory.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,11 @@ namespace NumSharp.Backends
77
{
88
public class BackendFactory
99
{
10-
public static ITensorEngine GetEngine(BackendType backendType = BackendType.Default)
10+
public static ITensorEngine GetEngine(BackendType backendType = BackendType.SIMD)
1111
{
1212
switch (backendType)
1313
{
14-
case BackendType.Default:
15-
return new DefaultEngine();
14+
case BackendType.MKL:
1615
case BackendType.SIMD:
1716
return new SimdEngine();
1817
case BackendType.ArrayFire:

src/NumSharp.Core/Backends/BackendType.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ namespace NumSharp
66
{
77
public enum BackendType
88
{
9-
Default = 1,
9+
MKL = 1,
10+
11+
CUDA = 2,
1012

1113
/// <summary>
1214
/// Managed SIMD
1315
/// </summary>
14-
SIMD = 2,
16+
SIMD = 3,
1517

1618
ArrayFire = 4
1719
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace NumSharp.Backends
6+
{
7+
public class CudaEngine : DefaultEngine
8+
{
9+
}
10+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using NumSharp.Interfaces;
2+
using System;
3+
using System.Linq;
4+
using System.Runtime.InteropServices;
5+
using System.Threading.Tasks;
6+
7+
namespace NumSharp.Backends
8+
{
9+
public abstract partial class DefaultEngine
10+
{
11+
public NDArray Dot(NDArray x, NDArray y)
12+
{
13+
var dtype = x.dtype;
14+
15+
if (x.ndim == 0 && y.ndim == 0)
16+
{
17+
switch (dtype.Name)
18+
{
19+
case "Int32":
20+
return y.Data<int>(0) * x.Data<int>(0);
21+
case "Single":
22+
return y.Data<float>(0) * x.Data<float>(0);
23+
}
24+
}
25+
else if (x.ndim == 1 && x.ndim == 1)
26+
{
27+
28+
switch (dtype.Name)
29+
{
30+
case "Int32":
31+
{
32+
int sum = 0;
33+
for (int i = 0; i < x.size; i++)
34+
sum += x.Data<int>(i) * y.Data<int>(i);
35+
return sum;
36+
}
37+
38+
case "Single":
39+
{
40+
float sum = 0;
41+
for (int i = 0; i < x.size; i++)
42+
sum += x.Data<float>(i) * y.Data<float>(i);
43+
return sum;
44+
}
45+
}
46+
}
47+
else if (x.ndim == 2 && y.ndim == 1)
48+
{
49+
// check size
50+
if (x.shape[1] != y.shape[0])
51+
throw new IncorrectSizeException($"shapes ({x.shape[0]},{x.shape[1]}) and ({y.shape[0]},) not aligned: {x.shape[1]} (dim 1) != {y.shape[0]} (dim 0)");
52+
var nd = new NDArray(dtype, new Shape(x.shape[0]));
53+
switch (dtype.Name)
54+
{
55+
case "Int32":
56+
for (int i = 0; i < x.shape[0]; i++)
57+
for (int j = 0; j < y.shape[0]; j++)
58+
nd.Data<int>()[i] += x.Data<int>(i, j) * y.Data<int>(j);
59+
break;
60+
case "Single":
61+
for (int i = 0; i < x.shape[0]; i++)
62+
for (int j = 0; j < y.shape[0]; j++)
63+
nd.Data<float>()[i] += x.Data<float>(i, j) * y.Data<float>(j);
64+
break;
65+
}
66+
return nd;
67+
}
68+
else if (x.ndim == 2 && y.ndim == 2)
69+
{
70+
// check size
71+
if (x.shape[1] != y.shape[0])
72+
throw new IncorrectSizeException($"shapes ({x.shape[0]},{x.shape[1]}) and ({y.shape[0]},{y.shape[1]}) not aligned: {x.shape[1]} (dim 1) != {y.shape[0]} (dim 0)");
73+
return np.matmul(x, y);
74+
}
75+
76+
throw new NotImplementedException($"dot {x.ndim} * {y.ndim}");
77+
}
78+
}
79+
}

0 commit comments

Comments
 (0)