Skip to content

Commit 9c2d5c4

Browse files
committed
fix NDArray creation in graph mode.
1 parent 2192f4d commit 9c2d5c4

12 files changed

Lines changed: 171 additions & 91 deletions

File tree

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,42 @@ public EagerTensor(SafeTensorHandleHandle handle)
1414
Resolve();
1515
}
1616

17+
#region scalar eager tensor
18+
public EagerTensor(bool value) : base(value)
19+
=> NewEagerTensorHandle(_handle);
20+
public EagerTensor(byte value) : base(value)
21+
=> NewEagerTensorHandle(_handle);
22+
public EagerTensor(sbyte value) : base(value)
23+
=> NewEagerTensorHandle(_handle);
24+
public EagerTensor(short value) : base(value)
25+
=> NewEagerTensorHandle(_handle);
26+
public EagerTensor(int value) : base(value)
27+
=> NewEagerTensorHandle(_handle);
28+
public EagerTensor(uint value) : base(value)
29+
=> NewEagerTensorHandle(_handle);
30+
public EagerTensor(long value) : base(value)
31+
=> NewEagerTensorHandle(_handle);
32+
public EagerTensor(ulong value) : base(value)
33+
=> NewEagerTensorHandle(_handle);
34+
public EagerTensor(float value) : base(value)
35+
=> NewEagerTensorHandle(_handle);
36+
public EagerTensor(double value) : base(value)
37+
=> NewEagerTensorHandle(_handle);
38+
#endregion
39+
1740
public EagerTensor(object value,string device_name, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
1841
{
1942
throw new NotImplementedException("");
2043
}
2144

22-
public EagerTensor(object value, Shape shape = null, string device_name = null, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
45+
public EagerTensor(object value, Shape? shape = null, string device_name = null, TF_DataType dtype = TF_DataType.TF_UINT8) : base((float[])value)
2346
{
2447
NewEagerTensorHandle(_handle);
2548
}
2649

50+
public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype)
51+
=> NewEagerTensorHandle(_handle);
52+
2753
internal unsafe EagerTensor(string value) : base(value)
2854
=> NewEagerTensorHandle(_handle);
2955

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,55 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Eager;
45
using static Tensorflow.Binding;
56

67
namespace Tensorflow.NumPy
78
{
89
public partial class NDArray
910
{
11+
public NDArray(bool value) => _tensor = new EagerTensor(value);
12+
public NDArray(byte value) => _tensor = new EagerTensor(value);
13+
public NDArray(short value) => _tensor = new EagerTensor(value);
14+
public NDArray(int value) => _tensor = new EagerTensor(value);
15+
public NDArray(long value) => _tensor = new EagerTensor(value);
16+
public NDArray(float value) => _tensor = new EagerTensor(value);
17+
public NDArray(double value) => _tensor = new EagerTensor(value);
18+
19+
public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape);
20+
21+
public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float)
22+
{
23+
Initialize(shape, dtype: dtype);
24+
}
25+
26+
public NDArray(Tensor value, Shape? shape = null)
27+
{
28+
if (shape is not null)
29+
_tensor = tf.reshape(value, shape);
30+
else
31+
_tensor = value;
32+
33+
if (_tensor.TensorDataPointer == IntPtr.Zero)
34+
_tensor = tf.get_default_session().eval(_tensor);
35+
}
36+
37+
public static NDArray Scalar<T>(T value) where T : unmanaged
38+
{
39+
return value switch
40+
{
41+
bool val => new NDArray(val),
42+
int val => new NDArray(val),
43+
float val => new NDArray(val),
44+
double val => new NDArray(val),
45+
_ => throw new NotImplementedException("")
46+
};
47+
}
48+
1049
void Initialize(Shape shape, NumpyDType dtype = NumpyDType.Float)
1150
{
12-
_tensor = tf.zeros(shape, dtype: dtype.as_tf_dtype());
51+
// _tensor = tf.zeros(shape, dtype: dtype.as_tf_dtype());
52+
_tensor = new EagerTensor(shape, dtype: dtype.as_tf_dtype());
1353
}
1454
}
1555
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,63 +17,6 @@ public partial class NDArray
1717
public Shape shape => _tensor.shape;
1818
public IntPtr data => _tensor.TensorDataPointer;
1919

20-
public NDArray(bool value)
21-
{
22-
_tensor = ops.convert_to_tensor(value);
23-
}
24-
25-
public NDArray(byte value)
26-
{
27-
_tensor = ops.convert_to_tensor(value);
28-
}
29-
30-
public NDArray(int value)
31-
{
32-
_tensor = ops.convert_to_tensor(value);
33-
}
34-
35-
public NDArray(float value)
36-
{
37-
_tensor = ops.convert_to_tensor(value);
38-
}
39-
40-
public NDArray(double value)
41-
{
42-
_tensor = ops.convert_to_tensor(value);
43-
}
44-
45-
public NDArray(Array value, Shape shape = null)
46-
{
47-
_tensor = ops.convert_to_tensor(value);
48-
}
49-
50-
public NDArray(Type dtype, Shape shape)
51-
{
52-
53-
}
54-
55-
public NDArray(Shape shape, NumpyDType dtype = NumpyDType.Float)
56-
{
57-
Initialize(shape, dtype: dtype);
58-
}
59-
60-
public NDArray(Tensor value, Shape? shape = null)
61-
{
62-
if (shape is not null)
63-
_tensor = tf.reshape(value, shape);
64-
else
65-
_tensor = value;
66-
}
67-
68-
public static NDArray Scalar<T>(T value) where T : unmanaged
69-
{
70-
return value switch
71-
{
72-
bool b => new NDArray(b),
73-
_ => throw new NotImplementedException("")
74-
};
75-
}
76-
7720
public T GetValue<T>(int index) where T : unmanaged
7821
=> _tensor.ToArray<T>()[index];
7922
public T GetAtIndex<T>(int index) where T : unmanaged

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,31 @@ private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] f
245245
return result;
246246
}
247247

248+
public unsafe Tensor eval(Tensor tensor)
249+
{
250+
var status = tf.Status;
251+
252+
var output_values = new IntPtr[1];
253+
var fetch_list = new[] { tensor._as_tf_output() };
254+
255+
c_api.TF_SessionRun(_handle,
256+
run_options: null,
257+
inputs: new TF_Output[0],
258+
input_values: new IntPtr[0],
259+
ninputs: 0,
260+
outputs: fetch_list,
261+
output_values: output_values,
262+
noutputs: 1,
263+
target_opers: new IntPtr[0],
264+
ntargets: 0,
265+
run_metadata: IntPtr.Zero,
266+
status: status.Handle);
267+
268+
status.Check(true);
269+
270+
return new Tensor(output_values[0]);
271+
}
272+
248273
private static unsafe NDArray fetchValue(IntPtr output)
249274
{
250275
var tensor = new Tensor(output);

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,7 @@ public NDArray[] build_results(BaseSession session, NDArray[] tensor_values)
7878
{
7979
var value = tensor_values[j];
8080
j += 1;
81-
if (value.ndim == 0)
82-
full_values.Add(value);
83-
else
84-
full_values.Add(value[np.arange(0, (int)value.dims[0])]);
81+
full_values.Add(value);
8582
}
8683
i += 1;
8784
}

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,48 +64,50 @@ public Tensor(IntPtr handle)
6464
#endif
6565
}
6666

67+
unsafe internal Tensor(Shape shape, TF_DataType dtype)
68+
=> _handle = TF_NewTensor(shape, dtype, null);
69+
6770
internal Tensor(Array array, Shape? shape = null)
6871
=> InitTensor(array, shape);
6972

7073
unsafe void InitTensor(Array array, Shape? shape = null)
7174
{
7275
shape = shape ?? array.GetShape();
7376
var dtype = array.GetType().GetElementType().as_tf_dtype();
74-
var length = (ulong)(array.Length * dtype.get_datatype_size());
7577

7678
switch (array)
7779
{
7880
case bool[] val:
7981
fixed (void* addr = &val[0])
80-
_handle = TF_NewTensor(shape, dtype, addr, length);
82+
_handle = TF_NewTensor(shape, dtype, addr);
8183
break;
8284
case int[] val:
8385
fixed (void* addr = &val[0])
84-
_handle = TF_NewTensor(shape, dtype, addr, length);
86+
_handle = TF_NewTensor(shape, dtype, addr);
8587
break;
8688
case int[,] val:
8789
fixed (void* addr = &val[0, 0])
88-
_handle = TF_NewTensor(shape, dtype, addr, length);
90+
_handle = TF_NewTensor(shape, dtype, addr);
8991
break;
9092
case long[] val:
9193
fixed (void* addr = &val[0])
92-
_handle = TF_NewTensor(shape, dtype, addr, length);
94+
_handle = TF_NewTensor(shape, dtype, addr);
9395
break;
9496
case float[] val:
9597
fixed (void* addr = &val[0])
96-
_handle = TF_NewTensor(shape, dtype, addr, length);
98+
_handle = TF_NewTensor(shape, dtype, addr);
9799
break;
98100
case float[,] val:
99101
fixed (void* addr = &val[0, 0])
100-
_handle = TF_NewTensor(shape, dtype, addr, length);
102+
_handle = TF_NewTensor(shape, dtype, addr);
101103
break;
102104
case double[] val:
103105
fixed (void* addr = &val[0])
104-
_handle = TF_NewTensor(shape, dtype, addr, length);
106+
_handle = TF_NewTensor(shape, dtype, addr);
105107
break;
106108
case double[,] val:
107109
fixed (void* addr = &val[0, 0])
108-
_handle = TF_NewTensor(shape, dtype, addr, length);
110+
_handle = TF_NewTensor(shape, dtype, addr);
109111
break;
110112
default:
111113
throw new NotImplementedException("");
@@ -131,7 +133,7 @@ public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes)
131133
}
132134

133135
public unsafe Tensor(NDArray nd)
134-
=> _handle = TF_NewTensor(nd.shape, nd.dtype.as_tf_dtype(), nd.data.ToPointer(), nd.size * nd.dtypesize);
136+
=> _handle = TF_NewTensor(nd.shape, nd.dtype.as_tf_dtype(), nd.data.ToPointer());
135137

136138
#region scala
137139
public Tensor(bool value) => _handle = TF_NewTensor(value);

src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public void Deconstruct(out long h, out long w)
1111
}
1212

1313
public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone());
14-
public static implicit operator Shape(TensorShape shape) => new Shape((long[])shape.dims.Clone());
14+
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone());
1515

1616
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
1717
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);

src/TensorFlowNET.Core/Tensors/c_api.tensor.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int
104104
return c_api.TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty);
105105
}
106106

107-
public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data, ulong length)
107+
public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data)
108108
{
109+
var length = shape.size * (ulong)dtype.get_datatype_size();
109110
var handle = TF_AllocateTensor(dtype, shape.dims, shape.ndim, length);
110111
var tensor = TF_TensorData(handle);
111112
System.Buffer.MemoryCopy(data, tensor.ToPointer(), length, length);

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ public static Tensor _constant_impl(object value,
9898
attrs: attrs,
9999
name: name);
100100

101+
var o = op.outputs;
101102
return op.outputs[0];
102103
}
103104

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,58 @@ public static TF_DataType as_tf_dtype(this Type type, TF_DataType? dtype = null)
182182
return dtype.Value;
183183
}
184184

185+
public static TF_DataType tf_dtype_from_name(string name)
186+
{
187+
TF_DataType dtype = TF_DataType.DtInvalid;
188+
switch (name.ToLower())
189+
{
190+
case "char":
191+
dtype = TF_DataType.TF_UINT8;
192+
break;
193+
case "boolean":
194+
dtype = TF_DataType.TF_BOOL;
195+
break;
196+
case "sbyte":
197+
dtype = TF_DataType.TF_INT8;
198+
break;
199+
case "byte":
200+
dtype = TF_DataType.TF_UINT8;
201+
break;
202+
case "int16":
203+
dtype = TF_DataType.TF_INT16;
204+
break;
205+
case "uint16":
206+
dtype = TF_DataType.TF_UINT16;
207+
break;
208+
case "int32":
209+
dtype = TF_DataType.TF_INT32;
210+
break;
211+
case "uint32":
212+
dtype = TF_DataType.TF_UINT32;
213+
break;
214+
case "int64":
215+
dtype = TF_DataType.TF_INT64;
216+
break;
217+
case "uint64":
218+
dtype = TF_DataType.TF_UINT64;
219+
break;
220+
case "single":
221+
dtype = TF_DataType.TF_FLOAT;
222+
break;
223+
case "double":
224+
dtype = TF_DataType.TF_DOUBLE;
225+
break;
226+
case "complex":
227+
dtype = TF_DataType.TF_COMPLEX128;
228+
break;
229+
case "string":
230+
dtype = TF_DataType.TF_STRING;
231+
break;
232+
}
233+
234+
return dtype;
235+
}
236+
185237
public static DataType as_datatype_enum(this TF_DataType type)
186238
{
187239
return (DataType)type;

0 commit comments

Comments
 (0)