Skip to content

Commit cf8aa9f

Browse files
authored
Merge pull request SciSharp#263 from teknologisk-institut/master
Added support for NDArray == NDArray and array_equal
2 parents 2a50a77 + 77d5742 commit cf8aa9f

10 files changed

Lines changed: 311 additions & 43 deletions

File tree

src/NumSharp.Core/APIs/np.load.cs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using System;
1+
using System;
22
using System.Collections;
33
using System.Collections.Generic;
44
using System.IO;
@@ -10,9 +10,38 @@ namespace NumSharp
1010
{
1111
public static partial class np
1212
{
13-
#region NpyFormat
13+
#region NpyFormat
1414

15-
public static T Load<T>(byte[] bytes)
15+
//Signature from numpy doc:
16+
// numpy.load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, encoding='ASCII')[source]
17+
public static NDArray load(string path)
18+
{
19+
using (var stream = new FileStream(path, FileMode.Open))
20+
return load(stream);
21+
}
22+
23+
public static NDArray load(Stream stream)
24+
{
25+
using (var reader = new BinaryReader(stream, System.Text.Encoding.ASCII
26+
#if !NET35 && !NET40
27+
, leaveOpen: true
28+
#endif
29+
))
30+
{
31+
int bytes;
32+
Type type;
33+
int[] shape;
34+
if (!parseReader(reader, out bytes, out type, out shape))
35+
throw new FormatException();
36+
37+
Array array = Array.CreateInstance(type, shape.Aggregate((dims, dim) => dims*dim));
38+
39+
var result = new NDArray(readValueMatrix(reader, array, bytes, type, shape));
40+
return result.reshape(shape);
41+
}
42+
}
43+
44+
public static T Load<T>(byte[] bytes)
1645
where T : class,
1746
#if !NETSTANDARD1_4
1847
ICloneable,
@@ -505,4 +534,4 @@ public static NpzDictionary<Array> LoadJagged_Npz(Stream stream, bool trim = tru
505534
}
506535
#endregion
507536
}
508-
}
537+
}

src/NumSharp.Core/APIs/np.logic.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ public static NDArray<bool> isfinite(NDArray a)
102102
/// <returns>The result is returned as a boolean array.</returns>
103103
public static NDArray<bool> isnan(NDArray a)
104104
=> BackendFactory.GetEngine().IsNan(a);
105+
106+
107+
public static bool array_equal(NDArray a, NDArray b)
108+
{
109+
return a.array_equal(b);
110+
}
105111
}
106112
}
107113

src/NumSharp.Core/APIs/np.save.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ public static partial class np
1212
{
1313
#region NpyFormat
1414

15+
public static void save(string filepath, Array arr)
16+
{
17+
Save(arr, filepath);
18+
}
19+
1520
public static byte[] Save(Array array)
1621
{
1722
using (var stream = new MemoryStream())
@@ -24,6 +29,10 @@ public static byte[] Save(Array array)
2429

2530
public static ulong Save(Array array, string path)
2631
{
32+
if (Path.GetExtension(path) != ".npy")
33+
{
34+
path += ".npy";
35+
}
2736
using (var stream = new FileStream(path, FileMode.Create))
2837
return Save(array, stream);
2938
}

src/NumSharp.Core/Backends/SIMD/SimdEngine.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public class SimdEngine : DefaultEngine
99
{
1010
public override NDArray Add(NDArray x, NDArray y)
1111
{
12+
return base.Add(x, y);
1213
if (x.ndim == y.ndim && x.ndim == 1)
1314
{
1415
switch (Type.GetTypeCode(x.dtype))

src/NumSharp.Core/Creation/np.arange.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,38 @@ namespace NumSharp
99
{
1010
public static partial class np
1111
{
12+
public static NDArray arange(float stop)
13+
{
14+
return arange(0, stop, 1);
15+
}
16+
1217
public static NDArray arange(double stop)
1318
{
1419
return arange(0, stop, 1);
1520
}
1621

22+
public static NDArray arange(float start, float stop, float step = 1)
23+
{
24+
if (start > stop)
25+
{
26+
throw new Exception("parameters invalid, start is greater than stop.");
27+
}
28+
29+
int length = (int)Math.Ceiling((stop - start + 0.0) / step);
30+
31+
var nd = new NDArray(typeof(float), new Shape(length));
32+
33+
float[] puffer = nd.Array as float[];
34+
35+
for (int index = 0; index < length; index++)
36+
{
37+
float value = start + index * step;
38+
puffer[index] = value;
39+
}
40+
41+
return nd;
42+
}
43+
1744
public static NDArray arange(double start, double stop, double step = 1)
1845
{
1946
if (start > stop)

src/NumSharp.Core/NumSharp.Core.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
4949
<DefineConstants>DEBUG;TRACE</DefineConstants>
5050
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
51+
<NoWarn>1701;1702;IDE1006</NoWarn>
5152
</PropertyGroup>
5253

5354
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">

src/NumSharp.Core/Operations/Elementwise/NDArray.Equals.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ public override bool Equals(object obj)
4040

4141
public static NDArray<bool> operator ==(NDArray np, object obj)
4242
{
43+
if (obj is NDArray np2)
44+
{
45+
return np.equal(np2);
46+
}
4347
var boolTensor = new NDArray(typeof(bool),np.shape);
4448
bool[] bools = boolTensor.Storage.GetData() as bool[];
4549

@@ -113,5 +117,60 @@ public override bool Equals(object obj)
113117

114118
return boolTensor.MakeGeneric<bool>();
115119
}
120+
121+
/// NumPy signature: numpy.equal(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj]) = <ufunc 'equal'>
122+
/// <summary>
123+
/// Compare two NDArrays element wise
124+
/// </summary>
125+
/// <param name="np2">NDArray to compare with</param>
126+
/// <returns>NDArray with result of each element compare</returns>
127+
private NDArray<bool> equal(NDArray np2)
128+
{
129+
if (this.size != np2.size)
130+
{
131+
throw new ArgumentException("Different sized NDArray's in not yet supported by the equal operation", nameof(np2));
132+
}
133+
var boolTensor = new NDArray(typeof(bool), this.shape);
134+
bool[] bools = boolTensor.Storage.GetData() as bool[];
135+
136+
var values1 = this.Storage.GetData();
137+
var values2 = np2.Storage.GetData();
138+
for (int idx = 0; idx < bools.Length; idx++)
139+
{
140+
var v1 = values1.GetValue(idx);
141+
var v2 = values2.GetValue(idx);
142+
if (v1.Equals(v2))
143+
bools[idx] = true;
144+
}
145+
146+
return boolTensor.MakeGeneric<bool>();
147+
}
148+
149+
/// NumPy signature: numpy.array_equal(a1, a2)[source]
150+
/// <summary>
151+
/// Compares two NDArrays
152+
/// </summary>
153+
/// <param name="np2"></param>
154+
/// <returns>True if two arrays have the same shape and elements, False otherwise.</returns>
155+
public bool array_equal(NDArray np2)
156+
{
157+
if (!Enumerable.SequenceEqual(this.shape, np2.shape))
158+
{
159+
return false;
160+
}
161+
var values1 = this.Storage.GetData();
162+
var values2 = np2.Storage.GetData();
163+
for (int idx = 0; idx < values1.Length; idx++)
164+
{
165+
var v1 = values1.GetValue(idx);
166+
var v2 = values2.GetValue(idx);
167+
if (!v1.Equals(v2))
168+
return false;
169+
}
170+
171+
return true;
172+
}
173+
174+
116175
}
117176
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using NumSharp.Generic;
5+
6+
namespace NumSharp
7+
{
8+
public partial class NDArray
9+
{
10+
11+
public static NDArray<bool> operator !(NDArray np_)
12+
{
13+
var boolTensor = new NDArray(typeof(bool), np_.shape);
14+
bool[] bools = boolTensor.Storage.GetData<bool>();
15+
16+
bool[] np = np_.Storage.GetData<bool>();
17+
18+
for (int i = 0; i < bools.Length; i++)
19+
bools[i] = !np[i];
20+
21+
return boolTensor.MakeGeneric<bool>();
22+
}
23+
}
24+
}
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Numerics;
4+
using System.Collections.Generic;
5+
using System.Text;
6+
using System.Linq;
7+
using NumSharp;
8+
using NumSharp.Generic;
9+
10+
namespace NumSharp.UnitTest.Operations
11+
{
12+
[TestClass]
13+
public class NDArrayEqualsTest
14+
{
15+
16+
[TestMethod]
17+
public void IntTwo1D_NDArrayEquals()
18+
{
19+
var np0 = new NDArray(new[] { 0, 0, 0, 0 }, new Shape(4));
20+
var np1 = new NDArray(new[] { 1, 2, 3, 4 }, new Shape(4));
21+
var np2 = new NDArray(new[] { 1, 2, 3, 4 }, new Shape(4));
22+
23+
var np3 = np1 == np2;
24+
Assert.IsTrue(Enumerable.SequenceEqual(new[] { true, true, true, true }, np3.Data<bool>()));
25+
var np3S = np.array_equal(np1, np2);
26+
Assert.IsTrue(np3S);
27+
28+
var np4 = np0 == np2;
29+
Assert.IsTrue(Enumerable.SequenceEqual(new[] { false, false, false, false }, np4.Data<bool>()));
30+
var np4S = np.array_equal(np0, np2);
31+
Assert.IsFalse(np4S);
32+
33+
34+
}
35+
36+
[TestMethod]
37+
public void IntAnd1D_NDArrayEquals()
38+
{
39+
var np1 = new NDArray(new[] { 1, 2, 3, 4 }, new Shape(4));
40+
41+
var np2 = np1 == 2;
42+
Assert.IsTrue(Enumerable.SequenceEqual(new[] { false, true, false, false }, np2.Data<bool>()));
43+
}
44+
45+
[TestMethod]
46+
public void IntTwo2D_NDArrayEquals()
47+
{
48+
var np1 = new NDArray(typeof(int), new Shape(2, 3));
49+
np1.SetData(new[] { 1, 2, 3, 4, 5, 6 });
50+
51+
var np2 = new NDArray(typeof(int), new Shape(2, 3));
52+
np2.SetData(new[] { 1, 2, 3, 4, 5, 6 });
53+
54+
var np3 = np1 == np2;
55+
56+
// expected
57+
var np3S = np.array_equal(np1, np2);
58+
Assert.IsTrue(np3S);
59+
var np4 = new bool[] { true, true, true, true, true, true };
60+
Assert.IsTrue(Enumerable.SequenceEqual(np3.Data<bool>(), np4));
61+
62+
63+
64+
var np5 = new NDArray(typeof(int), new Shape(2, 3));
65+
np5.SetData(new[] { 0, 0, 0, 0, 0, 0 });
66+
67+
var np6 = np1 == np5;
68+
// expected
69+
var np6S = np.array_equal(np1, np5);
70+
Assert.IsFalse(np6S);
71+
var np7 = new bool[] { false, false, false, false, false, false, };
72+
Assert.IsTrue(Enumerable.SequenceEqual(np6.Data<bool>(), np7));
73+
74+
}
75+
76+
77+
[TestMethod]
78+
public void IntAnd2D_NDArrayEquals()
79+
{
80+
var np1 = new NDArray(typeof(int), new Shape(2, 3));
81+
np1.SetData(new[] { 1, 2, 3, 4, 5, 6 });
82+
83+
var np2 = np1 == 2;
84+
85+
// expected
86+
var np3 = new bool[] { false, true, false, false, false, false };
87+
Assert.IsTrue(Enumerable.SequenceEqual(np2.Data<bool>(), np3));
88+
}
89+
}
90+
}

0 commit comments

Comments
 (0)