Skip to content

Commit 3c7207c

Browse files
committed
fix GetDataType.
1 parent 16d48a1 commit 3c7207c

18 files changed

Lines changed: 196 additions & 185 deletions

File tree

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,19 @@ public static TF_DataType GetDataType(this object data)
526526
var type = data.GetType();
527527
switch (data)
528528
{
529-
case Shape shape:
529+
case TensorShape:
530+
case Shape:
530531
return TF_DataType.TF_INT64;
532+
case Axis:
533+
return TF_DataType.TF_INT32;
534+
case NDArray nd:
535+
return nd.dtype;
536+
case Tensor tensor:
537+
return tensor.dtype;
538+
case Tensor[] tensor:
539+
return tensor[0].dtype;
540+
case ResourceVariable variable:
541+
return variable.dtype;
531542
default:
532543
return type.as_tf_dtype();
533544
}

src/TensorFlowNET.Core/Contexts/Context.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public bool has_graph_arg(params object[] args)
142142
bool has_graph_arg = !tf.Context.executing_eagerly();
143143
foreach (var el in flatten_args)
144144
{
145-
if (el is Tensor tensor && !tensor.IsEagerTensor)
145+
if (el is Tensor tensor && tensor.IsCreatedInGraphMode)
146146
{
147147
has_graph_arg = true;
148148
break;

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ public EagerTensor(object value, Shape? shape = null, string device_name = null,
5050
public EagerTensor(Shape shape, TF_DataType dtype) : base(shape, dtype)
5151
=> NewEagerTensorHandle(_handle);
5252

53-
internal unsafe EagerTensor(string value) : base(value)
54-
=> NewEagerTensorHandle(_handle);
55-
5653
internal unsafe EagerTensor(Array array, Shape shape) : base(array, shape)
5754
=> NewEagerTensorHandle(_handle);
5855

src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public void Record(Tensors flat_outputs, Tensors inference_args)
141141
src_graph: _func_graph);
142142

143143
var captures_from_forward = backwards_graph.external_captures
144-
.Where(x => !x.IsEagerTensor && x.graph == _func_graph)
144+
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
145145
.ToArray();
146146
foreach(var capture in captures_from_forward)
147147
{

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,47 @@ namespace Tensorflow.NumPy
88
{
99
public partial class NDArray
1010
{
11-
public NDArray(bool value) => _tensor = new EagerTensor(value);
12-
public NDArray(byte value) => _tensor = new EagerTensor(value);
13-
public NDArray(short value) => _tensor = new EagerTensor(value);
14-
public NDArray(int value) => _tensor = new EagerTensor(value);
15-
public NDArray(long value) => _tensor = new EagerTensor(value);
16-
public NDArray(float value) => _tensor = new EagerTensor(value);
17-
public NDArray(double value) => _tensor = new EagerTensor(value);
11+
public NDArray(bool value) => Init(value);
12+
public NDArray(byte value) => Init(value);
13+
public NDArray(short value) => Init(value);
14+
public NDArray(int value) => Init(value);
15+
public NDArray(long value) => Init(value);
16+
public NDArray(float value) => Init(value);
17+
public NDArray(double value) => Init(value);
18+
public NDArray(Array value, Shape? shape = null) => Init(value, shape);
19+
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) => Init(shape, dtype: dtype);
20+
public NDArray(Tensor value, Shape? shape = null) => Init(value, shape);
1821

19-
public NDArray(Array value, Shape? shape = null) => _tensor = new EagerTensor(value, shape);
22+
public static NDArray Scalar<T>(T value) where T : unmanaged
23+
=> value switch
24+
{
25+
bool val => new NDArray(val),
26+
byte val => new NDArray(val),
27+
int val => new NDArray(val),
28+
float val => new NDArray(val),
29+
double val => new NDArray(val),
30+
_ => throw new NotImplementedException("")
31+
};
32+
33+
void Init<T>(T value) where T : unmanaged
34+
{
35+
_tensor = new EagerTensor(value);
36+
_tensor.SetReferencedByNDArray();
37+
}
38+
39+
void Init(Array value, Shape? shape = null)
40+
{
41+
_tensor = new EagerTensor(value, shape ?? value.GetShape());
42+
_tensor.SetReferencedByNDArray();
43+
}
2044

21-
public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
22-
=> _tensor = new EagerTensor(shape, dtype: dtype);
45+
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
46+
{
47+
_tensor = new EagerTensor(shape, dtype: dtype);
48+
_tensor.SetReferencedByNDArray();
49+
}
2350

24-
public NDArray(Tensor value, Shape? shape = null)
51+
void Init(Tensor value, Shape? shape = null)
2552
{
2653
if (shape is not null)
2754
_tensor = tf.reshape(value, shape);
@@ -30,18 +57,8 @@ public NDArray(Tensor value, Shape? shape = null)
3057

3158
if (_tensor.TensorDataPointer == IntPtr.Zero)
3259
_tensor = tf.get_default_session().eval(_tensor);
33-
}
3460

35-
public static NDArray Scalar<T>(T value) where T : unmanaged
36-
{
37-
return value switch
38-
{
39-
bool val => new NDArray(val),
40-
int val => new NDArray(val),
41-
float val => new NDArray(val),
42-
double val => new NDArray(val),
43-
_ => throw new NotImplementedException("")
44-
};
61+
_tensor.SetReferencedByNDArray();
4562
}
4663
}
4764
}

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs

Lines changed: 79 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
using System.Numerics;
2222
using System.Text;
2323
using static Tensorflow.c_api;
24+
using static Tensorflow.Binding;
2425

2526
namespace Tensorflow
2627
{
@@ -31,7 +32,7 @@ public partial class Tensor
3132

3233
public Tensor()
3334
{
34-
35+
isCreatedInGraphMode = !tf.executing_eagerly();
3536
}
3637

3738
/// <summary>
@@ -41,60 +42,7 @@ public Tensor()
4142
public Tensor(IntPtr handle)
4243
{
4344
_handle = handle;
44-
//no need to set AllocationType = AllocationType.None;
45-
#if TRACK_TENSOR_LIFE
46-
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
47-
#endif
48-
}
49-
50-
unsafe internal Tensor(Shape shape, TF_DataType dtype)
51-
=> _handle = TF_NewTensor(shape, dtype, null);
52-
53-
internal Tensor(Array array, Shape? shape = null)
54-
=> InitTensor(array, shape);
55-
56-
unsafe void InitTensor(Array array, Shape? shape = null)
57-
{
58-
shape = shape ?? array.GetShape();
59-
var dtype = array.GetType().GetElementType().as_tf_dtype();
60-
61-
switch (array)
62-
{
63-
case bool[] val:
64-
fixed (void* addr = &val[0])
65-
_handle = TF_NewTensor(shape, dtype, addr);
66-
break;
67-
case int[] val:
68-
fixed (void* addr = &val[0])
69-
_handle = TF_NewTensor(shape, dtype, addr);
70-
break;
71-
case int[,] val:
72-
fixed (void* addr = &val[0, 0])
73-
_handle = TF_NewTensor(shape, dtype, addr);
74-
break;
75-
case long[] val:
76-
fixed (void* addr = &val[0])
77-
_handle = TF_NewTensor(shape, dtype, addr);
78-
break;
79-
case float[] val:
80-
fixed (void* addr = &val[0])
81-
_handle = TF_NewTensor(shape, dtype, addr);
82-
break;
83-
case float[,] val:
84-
fixed (void* addr = &val[0, 0])
85-
_handle = TF_NewTensor(shape, dtype, addr);
86-
break;
87-
case double[] val:
88-
fixed (void* addr = &val[0])
89-
_handle = TF_NewTensor(shape, dtype, addr);
90-
break;
91-
case double[,] val:
92-
fixed (void* addr = &val[0, 0])
93-
_handle = TF_NewTensor(shape, dtype, addr);
94-
break;
95-
default:
96-
throw new NotImplementedException("");
97-
}
45+
isCreatedInGraphMode = !tf.executing_eagerly();
9846
}
9947

10048
/// <summary>
@@ -109,22 +57,26 @@ unsafe void InitTensor(Array array, Shape? shape = null)
10957
public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes)
11058
{
11159
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes);
60+
isCreatedInGraphMode = !tf.executing_eagerly();
11261
}
11362

11463
public unsafe Tensor(NDArray nd)
115-
=> _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
64+
{
65+
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
66+
isCreatedInGraphMode = !tf.executing_eagerly();
67+
}
11668

11769
#region scala
118-
public Tensor(bool value) => _handle = TF_NewTensor(value);
119-
public Tensor(byte value) => _handle = TF_NewTensor(value);
120-
public Tensor(sbyte value) => _handle = TF_NewTensor(value);
121-
public Tensor(short value) => _handle = TF_NewTensor(value);
122-
public Tensor(int value) => _handle = TF_NewTensor(value);
123-
public Tensor(uint value) => _handle = TF_NewTensor(value);
124-
public Tensor(long value) => _handle = TF_NewTensor(value);
125-
public Tensor(ulong value) => _handle = TF_NewTensor(value);
126-
public Tensor(float value) => _handle = TF_NewTensor(value);
127-
public Tensor(double value) => _handle = TF_NewTensor(value);
70+
public Tensor(bool value) => InitTensor(value);
71+
public Tensor(byte value) => InitTensor(value);
72+
public Tensor(sbyte value) => InitTensor(value);
73+
public Tensor(short value) => InitTensor(value);
74+
public Tensor(int value) => InitTensor(value);
75+
public Tensor(uint value) => InitTensor(value);
76+
public Tensor(long value) => InitTensor(value);
77+
public Tensor(ulong value) => InitTensor(value);
78+
public Tensor(float value) => InitTensor(value);
79+
public Tensor(double value) => InitTensor(value);
12880
#endregion
12981

13082
#region 1d array
@@ -142,31 +94,74 @@ public unsafe Tensor(NDArray nd)
14294
public Tensor(Complex[] data, Shape? shape = null) => InitTensor(data, shape);
14395
#endregion
14496

145-
/// <summary>
146-
/// Create a string Tensor from the given string
147-
/// </summary>
148-
public Tensor(string str)
97+
public Tensor(Operation op, int value_index, TF_DataType dtype)
98+
{
99+
_op = op;
100+
_value_index = value_index;
101+
_override_dtype = dtype;
102+
_id = ops.uid();
103+
isCreatedInGraphMode = !tf.executing_eagerly();
104+
}
105+
106+
internal Tensor(Shape shape, TF_DataType dtype) => InitTensor(shape, dtype);
107+
internal Tensor(Array array, Shape? shape = null) => InitTensor(array, shape);
108+
internal Tensor(string value) => InitTensor(value);
109+
110+
protected unsafe void InitTensor<T>(T data) where T : unmanaged
149111
{
150-
_handle = StringTensor(new string[] { str }, TensorShape.Scalar);
151-
#if TRACK_TENSOR_LIFE
152-
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
153-
#endif
112+
_handle = TF_NewTensor(data);
113+
isCreatedInGraphMode = !tf.executing_eagerly();
154114
}
155115

156-
public Tensor(string[] strings)
116+
protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
157117
{
158-
_handle = StringTensor(strings, new TensorShape(strings.Length));
159-
#if TRACK_TENSOR_LIFE
160-
print($"New Tensor 0x{_handle.ToString("x16")} {AllocationType} String Data: 0x{TensorDataPointer.ToString("x16")}");
161-
#endif
118+
_handle = TF_NewTensor(shape, dtype, null);
119+
isCreatedInGraphMode = !tf.executing_eagerly();
162120
}
163121

164-
public Tensor(Operation op, int value_index, TF_DataType dtype)
122+
protected void InitTensor(string value)
165123
{
166-
_op = op;
167-
_value_index = value_index;
168-
_override_dtype = dtype;
169-
_id = ops.uid();
124+
_handle = StringTensor(new[] { value }, TensorShape.Scalar);
125+
isCreatedInGraphMode = !tf.executing_eagerly();
126+
}
127+
128+
protected unsafe void InitTensor(Array array, Shape? shape = null)
129+
{
130+
shape = shape ?? array.GetShape();
131+
var dtype = array.GetType().GetElementType().as_tf_dtype();
132+
133+
switch (array)
134+
{
135+
case bool[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
136+
case bool[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
137+
case bool[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
138+
case bool[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
139+
case byte[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
140+
case byte[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
141+
case byte[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
142+
case byte[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
143+
case int[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
144+
case int[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
145+
case int[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
146+
case int[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
147+
case long[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
148+
case long[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
149+
case long[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
150+
case long[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
151+
case float[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
152+
case float[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
153+
case float[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
154+
case float[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
155+
case double[] val: fixed (void* addr = &val[0]) _handle = TF_NewTensor(shape, dtype, addr); break;
156+
case double[,] val: fixed (void* addr = &val[0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
157+
case double[,,] val: fixed (void* addr = &val[0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
158+
case double[,,,] val: fixed (void* addr = &val[0, 0, 0, 0]) _handle = TF_NewTensor(shape, dtype, addr); break;
159+
case string[] val: _handle = StringTensor(val, shape); break;
160+
default:
161+
throw new NotImplementedException("");
162+
}
163+
164+
isCreatedInGraphMode = !tf.executing_eagerly();
170165
}
171166
}
172167
}

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public IntPtr StringTensor(string[] strings, TensorShape shape)
2323
public IntPtr StringTensor(byte[][] buffer, TensorShape shape)
2424
{
2525
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
26-
shape.ndim == 0 ? null : shape.dims.Select(x => (long)x).ToArray(),
26+
shape.ndim == 0 ? null : shape.dims,
2727
shape.ndim,
2828
(ulong)shape.size * TF_TSRING_SIZE);
2929

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,13 @@ public partial class Tensor : DisposableObject,
9393
/// TFE_TensorHandle
9494
/// </summary>
9595
public SafeTensorHandleHandle EagerTensorHandle { get; set; }
96-
protected bool _createdInGraphMode;
97-
public bool CreatedInGraphMode => _createdInGraphMode;
98-
public bool IsEagerTensor => this is EagerTensor;
96+
97+
protected bool isReferencedByNDArray;
98+
public bool IsReferencedByNDArray => isReferencedByNDArray;
99+
100+
protected bool isCreatedInGraphMode;
101+
102+
public bool IsCreatedInGraphMode => isCreatedInGraphMode;
99103
public bool IsSparseTensor => this is SparseTensor;
100104

101105
/// <summary>
@@ -207,6 +211,8 @@ public TF_Output _as_tf_output()
207211
return _tf_output.Value;
208212
}
209213

214+
public void SetReferencedByNDArray() => isReferencedByNDArray = true;
215+
210216
public Tensor MaybeMove()
211217
{
212218
var tensor = c_api.TF_TensorMaybeMove(_handle);

src/TensorFlowNET.Core/Tensors/TensorShape.Convert.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Tensorflow.NumPy;
1+
using System.Linq;
2+
using Tensorflow.NumPy;
23

34
namespace Tensorflow
45
{
@@ -13,7 +14,7 @@ public void Deconstruct(out long h, out long w)
1314
public static implicit operator TensorShape(Shape shape) => new TensorShape((long[])shape.dims.Clone());
1415
public static implicit operator Shape(TensorShape shape) => shape == null ? null : new Shape((long[])shape.dims.Clone());
1516

16-
public static implicit operator int[](TensorShape shape) => shape == null ? null : (int[])shape.dims.Clone(); //we clone to avoid any changes
17+
public static implicit operator int[](TensorShape shape) => shape == null ? null : shape.dims.Select(x => (int)x).ToArray(); //we clone to avoid any changes
1718
public static implicit operator TensorShape(int[] dims) => dims == null ? null : new TensorShape(dims);
1819

1920
public static implicit operator long[](TensorShape shape) => shape == null ? null : (long[])shape.dims.Clone(); //we clone to avoid any changes

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public class Tensors : IEnumerable<Tensor>, IDisposable
2121
public TensorShape shape => items.First().TensorShape;
2222
public int rank => items.First().rank;
2323
public Graph graph => items.First().graph;
24-
public bool IsEagerTensor => items.First().IsEagerTensor;
24+
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode;
2525
public bool IsList { get; set; }
2626
public int Length => items.Count();
2727

0 commit comments

Comments
 (0)