Skip to content

Commit efb1c24

Browse files
committed
fix constant_value when referenced by ndarray.
1 parent 39883ae commit efb1c24

7 files changed

Lines changed: 18 additions & 45 deletions

File tree

src/TensorFlowNET.Core/DisposableObject.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,12 @@ public abstract class DisposableObject : IDisposable
2929
protected IntPtr _handle;
3030
protected bool _disposed;
3131

32-
[SuppressMessage("ReSharper", "UnusedMember.Global")]
3332
protected DisposableObject()
3433
{ }
3534

3635
protected DisposableObject(IntPtr handle)
3736
=> _handle = handle;
3837

39-
[SuppressMessage("ReSharper", "InvertIf")]
4038
private void Dispose(bool disposing)
4139
{
4240
if (_disposed)

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,5 @@ void copy_handle_data(Tensor target_t)
9494
// c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle);
9595
}
9696
}
97-
98-
protected override void DisposeUnmanagedResources(IntPtr handle)
99-
{
100-
base.DisposeUnmanagedResources(handle);
101-
_eagerTensorHandle.Dispose();
102-
}
10397
}
10498
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,44 +43,30 @@ void Init<T>(T value) where T : unmanaged
4343
double val => new Tensor(val),
4444
_ => throw new NotImplementedException("")
4545
};
46-
_tensor.SetReferencedByNDArray();
4746

48-
var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
49-
_tensor.SetEagerTensorHandle(_handle);
47+
_tensor.SetReferencedByNDArray();
5048
}
5149

5250
void Init(Array value, Shape? shape = null)
5351
{
5452
_tensor = new Tensor(value, shape ?? value.GetShape());
5553
_tensor.SetReferencedByNDArray();
56-
57-
var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
58-
_tensor.SetEagerTensorHandle(_handle);
5954
}
6055

6156
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
6257
{
6358
_tensor = new Tensor(shape, dtype: dtype);
6459
_tensor.SetReferencedByNDArray();
65-
66-
var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
67-
_tensor.SetEagerTensorHandle(_handle);
6860
}
6961

7062
void Init(Tensor value, Shape? shape = null)
7163
{
72-
if (shape is not null)
73-
_tensor = new Tensor(value.TensorDataPointer, shape, value.dtype);
74-
else
75-
_tensor = value;
76-
77-
if (_tensor.TensorDataPointer == IntPtr.Zero)
78-
_tensor = tf.get_default_session().eval(_tensor);
64+
// created tensor in graph mode
65+
if (value.TensorDataPointer == IntPtr.Zero)
66+
value = tf.defaultSession.eval(value);
7967

68+
_tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
8069
_tensor.SetReferencedByNDArray();
81-
82-
var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
83-
_tensor.SetEagerTensorHandle(_handle);
8470
}
8571
}
8672
}

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace Tensorflow.NumPy
88
{
9-
public partial class NDArray : DisposableObject
9+
public partial class NDArray
1010
{
1111
Tensor _tensor;
1212
public TF_DataType dtype => _tensor.dtype;
@@ -58,11 +58,5 @@ public override string ToString()
5858
{
5959
return tensor_util.to_numpy_string(_tensor);
6060
}
61-
62-
protected override void DisposeUnmanagedResources(IntPtr handle)
63-
{
64-
_tensor.EagerTensorHandle.Dispose();
65-
_tensor.Dispose();
66-
}
6761
}
6862
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,12 @@ public TF_Output _as_tf_output()
212212
return _tf_output.Value;
213213
}
214214

215-
public void SetReferencedByNDArray() => isReferencedByNDArray = true;
216-
public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle;
217-
215+
public void SetReferencedByNDArray()
216+
{
217+
isReferencedByNDArray = true;
218+
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
219+
}
220+
218221
public Tensor MaybeMove()
219222
{
220223
var tensor = c_api.TF_TensorMaybeMove(_handle);
@@ -256,7 +259,6 @@ public override string ToString()
256259
}
257260
}
258261

259-
[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
260262
protected override void DisposeUnmanagedResources(IntPtr handle)
261263
{
262264
if (dtype == TF_DataType.TF_STRING)
@@ -274,6 +276,9 @@ protected override void DisposeUnmanagedResources(IntPtr handle)
274276
}
275277

276278
c_api.TF_DeleteTensor(handle);
279+
280+
if (_eagerTensorHandle is not null)
281+
_eagerTensorHandle.Dispose();
277282
}
278283

279284
public bool IsDisposed => _disposed;

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ public static class tensor_util
3535
/// <returns></returns>
3636
public static NDArray constant_value(Tensor tensor, bool partial = false)
3737
{
38-
if (tensor is EagerTensor)
38+
if (tensor.IsReferencedByNDArray)
39+
return new NDArray(tensor);
40+
else if (tensor is EagerTensor)
3941
return tensor.numpy();
4042

4143
NDArray ret = _ConstantValue(tensor, partial);

test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,5 @@ public void TestInit()
1010
{
1111
tf.compat.v1.disable_eager_execution();
1212
}
13-
14-
[TestCleanup]
15-
public void TestClean()
16-
{
17-
tf.enable_eager_execution();
18-
}
1913
}
2014
}

0 commit comments

Comments
 (0)