Skip to content

Commit a947fd8

Browse files
committed
fix NDArray indexing.
1 parent da33a8b commit a947fd8

31 files changed

Lines changed: 186 additions & 352 deletions

File tree

src/TensorFlowNET.Console/MemoryBasicTest.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public Action<int, int> Constant
2929
public Action<int, int> Constant2x3
3030
=> (epoch, iterate) =>
3131
{
32-
var nd = np.arange(1000).reshape(10, 100);
32+
var nd = np.arange(1000).reshape((10, 100));
3333
var tensor = tf.constant(nd);
3434
var data = tensor.numpy();
3535
};
@@ -51,14 +51,14 @@ public Action<int, int> ConstantString
5151
public Action<int, int> Variable
5252
=> (epoch, iterate) =>
5353
{
54-
var nd = np.arange(1 * 256 * 256 * 3).reshape(1, 256, 256, 3);
54+
var nd = np.arange(1 * 256 * 256 * 3).reshape((1, 256, 256, 3));
5555
ResourceVariable variable = tf.Variable(nd);
5656
};
5757

5858
public Action<int, int> VariableRead
5959
=> (epoch, iterate) =>
6060
{
61-
var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape(1, 256, 256, 3);
61+
var nd = np.zeros(1 * 256 * 256 * 3).astype(np.float32).reshape((1, 256, 256, 3));
6262
ResourceVariable variable = tf.Variable(nd);
6363

6464
for (int i = 0; i< 10; i++)

src/TensorFlowNET.Core/Data/MnistDataSet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
1717

1818
NumOfExamples = (int)images.dims[0];
1919

20-
images = images.reshape(images.dims[0], images.dims[1] * images.dims[2]);
20+
images = images.reshape((images.dims[0], images.dims[1] * images.dims[2]));
2121
images = images.astype(dataType);
2222
// for debug np.multiply performance
2323
var sw = new Stopwatch();

src/TensorFlowNET.Core/Data/MnistModelLoader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ private NDArray ExtractImages(string file, int? limit = null)
124124
bytestream.Read(buf, 0, buf.Length);
125125

126126
var data = np.frombuffer(buf, np.@byte);
127-
data = data.reshape(num_images, rows, cols, 1);
127+
data = data.reshape((num_images, rows, cols, 1));
128128

129129
return data;
130130
}

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ bool SetOpAttrScalar(Context ctx, SafeOpHandle op,
360360
c_api.TFE_OpSetAttrFloat(op, key, Convert.ToSingle(value));
361361
break;
362362
case TF_AttrType.TF_ATTR_SHAPE:
363-
var dims = (value as int[]).Select(x => (long)x).ToArray();
363+
var dims = (value as long[]).ToArray();
364364
c_api.TFE_OpSetAttrShape(op, key, dims, dims.Length, status.Handle);
365365
status.Check(true);
366366
break;

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ public EagerTensor(object value, Shape shape = null, string device_name = null,
2424
NewEagerTensorHandle(_handle);
2525
}
2626

27+
internal unsafe EagerTensor(string value) : base(value)
28+
=> NewEagerTensorHandle(_handle);
29+
2730
internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape)
2831
=> NewEagerTensorHandle(_handle);
2932

src/TensorFlowNET.Core/Framework/graph_util_impl.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,9 @@ private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data,
135135
output_node.Attr["dtype"] = dtype;
136136
output_node.Attr["value"] = new AttrValue()
137137
{
138-
Tensor = tensor_util.make_tensor_proto(
139-
data, dtype: dtype.Type.as_tf_dtype(), shape: data_shape)
138+
Tensor = tensor_util.make_tensor_proto(data,
139+
dtype: dtype.Type.as_tf_dtype(),
140+
shape: data_shape)
140141
};
141142

142143
return output_node;

src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using static Tensorflow.Binding;
56

@@ -15,8 +16,17 @@ public override bool Equals(object obj)
1516
long val => GetAtIndex<long>(0) == val,
1617
float val => GetAtIndex<float>(0) == val,
1718
double val => GetAtIndex<double>(0) == val,
19+
NDArray val => Equals(this, val),
1820
_ => base.Equals(obj)
1921
};
2022
}
23+
24+
bool Equals(NDArray x, NDArray y)
25+
{
26+
if (x.ndim != y.ndim)
27+
return false;
28+
29+
return Enumerable.SequenceEqual(x.ToByteArray(), y.ToByteArray());
30+
}
2131
}
2232
}

src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ public static implicit operator byte[](NDArray nd)
1818
public static implicit operator int(NDArray nd)
1919
=> nd._tensor.ToArray<int>()[0];
2020

21+
public static implicit operator float(NDArray nd)
22+
=> nd._tensor.ToArray<float>()[0];
23+
2124
public static implicit operator double(NDArray nd)
2225
=> nd._tensor.ToArray<double>()[0];
2326

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using static Tensorflow.Binding;
6+
7+
namespace Tensorflow.NumPy
8+
{
9+
public partial class NDArray
10+
{
11+
public NDArray this[int index]
12+
{
13+
get
14+
{
15+
return _tensor[index];
16+
}
17+
18+
set
19+
{
20+
21+
}
22+
}
23+
24+
public NDArray this[params int[] index]
25+
{
26+
get
27+
{
28+
return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
29+
}
30+
31+
set
32+
{
33+
34+
}
35+
}
36+
37+
public NDArray this[params Slice[] slices]
38+
{
39+
get
40+
{
41+
return _tensor[slices];
42+
}
43+
44+
set
45+
{
46+
47+
}
48+
}
49+
50+
public NDArray this[NDArray mask]
51+
{
52+
get
53+
{
54+
throw new NotImplementedException("");
55+
}
56+
57+
set
58+
{
59+
60+
}
61+
}
62+
}
63+
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 8 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using static Tensorflow.Binding;
56

67
namespace Tensorflow.NumPy
78
{
@@ -55,48 +56,12 @@ public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float)
5556
Initialize(shape, dtype: dtype);
5657
}
5758

58-
public NDArray(Tensor value)
59+
public NDArray(Tensor value, Shape? shape = null)
5960
{
60-
_tensor = value;
61-
}
62-
63-
public NDArray this[params int[] index]
64-
{
65-
get
66-
{
67-
throw new NotImplementedException("");
68-
}
69-
70-
set
71-
{
72-
73-
}
74-
}
75-
76-
public NDArray this[params Slice[] slices]
77-
{
78-
get
79-
{
80-
throw new NotImplementedException("");
81-
}
82-
83-
set
84-
{
85-
86-
}
87-
}
88-
89-
public NDArray this[NDArray mask]
90-
{
91-
get
92-
{
93-
throw new NotImplementedException("");
94-
}
95-
96-
set
97-
{
98-
99-
}
61+
if (shape is not null)
62+
_tensor = tf.reshape(value, shape);
63+
else
64+
_tensor = value;
10065
}
10166

10267
public static NDArray Scalar<T>(T value) where T : unmanaged
@@ -129,15 +94,14 @@ public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
12994

13095
public bool HasNext() => throw new NotImplementedException("");
13196
public T MoveNext<T>() => throw new NotImplementedException("");
132-
public NDArray reshape(params int[] shape) => throw new NotImplementedException("");
133-
public NDArray reshape(params long[] shape) => throw new NotImplementedException("");
97+
public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape);
13498
public NDArray astype(Type type) => throw new NotImplementedException("");
13599
public NDArray astype(NumpyDType type) => throw new NotImplementedException("");
136100
public bool array_equal(NDArray rhs) => throw new NotImplementedException("");
137101
public NDArray ravel() => throw new NotImplementedException("");
138102
public void shuffle(NDArray nd) => throw new NotImplementedException("");
139103
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");
140-
public byte[] ToByteArray() => _tensor.ToArray<byte>();
104+
public byte[] ToByteArray() => _tensor.BufferToArray();
141105
public static string[] AsStringArray(NDArray arr) => throw new NotImplementedException("");
142106

143107
public T[] Data<T>() where T : unmanaged

0 commit comments

Comments
 (0)