Skip to content

Commit 6989cb1

Browse files
committed
Add EagerRunner to fix memory leak and exception.
1 parent 56a9661 commit 6989cb1

63 files changed

Lines changed: 1968 additions & 2613 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/TensorFlowNET.Console/Program.cs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using static Tensorflow.Binding;
23

34
namespace Tensorflow
45
{
@@ -14,18 +15,19 @@ static void Main(string[] args)
1415

1516
int batchSize = 1000;
1617

17-
// 1 million float tensor 58.5M.
18+
// 1 million float tensor 68M.
1819
mm.Execute(10, 100 * batchSize, cases.Constant);
1920

20-
// 100K float variable 80.5M.
21+
// 100K float variable 84M.
2122
mm.Execute(10, 10 * batchSize, cases.Variable);
2223

23-
// 1 million math add 36.5M.
24+
// 1 million math add 39M.
2425
mm.Execute(10, 100 * batchSize, cases.MathAdd);
2526

26-
// 100K gradient 210M.
27+
// 100K gradient 44M.
2728
mm.Execute(10, 10 * batchSize, cases.Gradient);
2829

30+
// 120M
2931
Console.WriteLine("Finished.");
3032
Console.ReadLine();
3133
}

src/TensorFlowNET.Core/APIs/tf.gradients.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ namespace Tensorflow
2020
{
2121
public partial class tensorflow
2222
{
23-
public GradientTape GradientTape()
24-
=> new GradientTape();
23+
public GradientTape GradientTape(bool persistent = false,
24+
bool watch_accessed_variables = true)
25+
=> new GradientTape(persistent: persistent,
26+
watch_accessed_variables: watch_accessed_variables);
2527

2628
public Tensor[] gradients(Tensor[] ys,
2729
Tensor[] xs,

src/TensorFlowNET.Core/Eager/Context.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,30 @@ public Context(ContextOptions opts, Status status)
1818
status.Check(true);
1919
}
2020

21+
/// <summary>
22+
/// Initialize handle and devices if not already done so.
23+
/// </summary>
2124
public void ensure_initialized()
2225
{
2326
if (_initialized)
2427
return;
2528
_initialized = true;
2629
}
2730

31+
public void start_step()
32+
=> c_api.TFE_ContextStartStep(_handle);
33+
34+
public void end_step()
35+
=> c_api.TFE_ContextEndStep(_handle);
36+
2837
/// <summary>
2938
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
3039
/// </summary>
3140
protected sealed override void DisposeUnmanagedResources(IntPtr handle)
3241
=> c_api.TFE_DeleteContext(_handle);
3342

34-
35-
public bool executing_eagerly() => true;
43+
public bool executing_eagerly()
44+
=> default_execution_mode == EAGER_MODE;
3645

3746
public string shared_name(string name = null)
3847
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ?

src/TensorFlowNET.Core/Eager/EagerOperation.cs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,19 @@ namespace Tensorflow.Eager
88
{
99
public class EagerOperation : Operation
1010
{
11-
static Dictionary<string, OpDef> op_dict;
1211
public string Name { get; set; }
1312
public new int NumInputs;
1413
public IntPtr[] InputHandles { get; set; }
1514
public Tensor[] Inputs { get; set; }
1615
public new int NumOutputs;
1716
public IntPtr[] OutputHandles { get; set; }
1817
public Tensor[] Outputs { get; set; }
19-
public BindingArray SkipInputIndicesArray { get; set; }
20-
public unsafe int[] SkipInputIndices => SkipInputIndicesArray.Data.Select(x => *(int*) x).ToArray();
21-
public string[] AttrsArray { get; set; }
18+
public long[] SkipInputIndices { get; set; }
19+
public object[] Attrs { get; set; }
2220

2321
public EagerOperation() : base(IntPtr.Zero)
2422
{
25-
if (op_dict == null)
26-
op_dict = op_def_registry.get_registered_ops();
23+
2724
}
2825

2926
public override InputList inputs
@@ -72,9 +69,9 @@ public override object get_attr(string attr_name)
7269

7370
public bool get_attr_bool(string attr_name)
7471
{
75-
for (int i = 0; i < AttrsArray.Length; i = i + 2)
76-
if (AttrsArray[i] == attr_name)
77-
return AttrsArray[i + 1] == "1";
72+
for (int i = 0; i < Attrs.Length; i = i + 2)
73+
if (Attrs[i].Equals(attr_name))
74+
return Attrs[i + 1].Equals("1");
7875

7976
throw new ValueError($"Can't find attr: {attr_name}");
8077
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Gradients;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
public class EagerRunner : IEagerRunner
9+
{
10+
public Tensor[] TFE_Execute(Context ctx, string device_name, string op_name, Tensor[] inputs, object[] attrs, int num_outputs)
11+
{
12+
throw new NotImplementedException();
13+
}
14+
15+
public Tensor[] TFE_FastPathExecute(Context ctx, string device_name, string opName, string name, Action callbacks, params object[] args)
16+
{
17+
throw new NotImplementedException();
18+
}
19+
20+
public Tensor[] TFE_TapeGradient(ITape tape, Tensor[] target, Tensor[] sources, Tensor[] output_gradients)
21+
{
22+
throw new NotImplementedException();
23+
}
24+
}
25+
}

src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,6 @@ public EagerTensor Resolve()
4242
//print($"new Tensor {Id} {_handle.ToString("x16")}");
4343
//print($"new TensorHandle {Id} {EagerTensorHandle.ToString("x16")}");
4444

45-
/*GarbageCollector.Increase(_handle, GCItemType.TensorHandle);
46-
GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle);*/
47-
4845
return this;
4946
}
5047

@@ -53,10 +50,6 @@ public override IntPtr ToPointer()
5350

5451
protected override void DisposeUnmanagedResources(IntPtr handle)
5552
{
56-
/*GarbageCollector.Decrease(_handle);
57-
GarbageCollector.Decrease(tfe_tensor_handle);
58-
GarbageCollector.Decrease(EagerTensorHandle);*/
59-
6053
//print($"deleting DeleteTensorHandle {Id} {_handle.ToString("x16")}");
6154
c_api.TF_DeleteTensor(_handle);
6255
//print($"deleting DeleteTensorHandle {Id} {EagerTensorHandle.ToString("x16")}");

src/TensorFlowNET.Core/Eager/Execute.cs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Collections.Generic;
22
using System;
33
using System.Linq;
4+
using static Tensorflow.Binding;
45

56
namespace Tensorflow.Eager
67
{
@@ -27,20 +28,18 @@ public class Execute
2728
/// <param name="ctx">The value of context.context().</param>
2829
/// <param name="name">Customized name for the operation.</param>
2930
/// <returns>List of output Tensor objects. The list is empty if there are no outputs</returns>
30-
public EagerTensor[] execute(Context ctx, string op_name, int num_outputs,
31-
EagerTensor[] inputs, object[] attrs,
31+
public Tensor[] execute(Context ctx, string op_name, int num_outputs,
32+
Tensor[] inputs, object[] attrs,
3233
string name = null)
3334
{
3435
ctx.ensure_initialized();
3536

36-
using var status = new Status();
37-
var results = wrap_tfe_src.TFE_Execute(ctx,
37+
var results = tf.Runner.TFE_Execute(ctx,
3838
ctx.device_name,
3939
op_name,
4040
inputs,
4141
attrs,
42-
num_outputs,
43-
status);
42+
num_outputs);
4443

4544
return results;
4645
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Eager
6+
{
7+
public class FastPathOpExecInfo
8+
{
9+
public Context ctx { get; set; }
10+
public string device_name { get; set; }
11+
public string op_name { get; set; }
12+
public string name { get; set; }
13+
public object[] args { get; set; }
14+
public bool run_gradient_callback { get; set; }
15+
public bool run_post_exec_callbacks { get; set; }
16+
public bool run_callbacks { get; set; }
17+
}
18+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Gradients;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
public interface IEagerRunner
9+
{
10+
public Tensor[] TFE_FastPathExecute(Context ctx,
11+
string device_name,
12+
string opName,
13+
string name,
14+
Action callbacks,
15+
params object[] args);
16+
17+
public Tensor[] TFE_Execute(Context ctx,
18+
string device_name,
19+
string op_name,
20+
Tensor[] inputs,
21+
object[] attrs,
22+
int num_outputs);
23+
24+
public Tensor[] TFE_TapeGradient(ITape tape,
25+
Tensor[] target,
26+
Tensor[] sources,
27+
Tensor[] output_gradients);
28+
}
29+
}

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ public delegate void delete_backward_function_callback(string op_name,
116116
[DllImport(TensorFlowLibName)]
117117
public static extern TFE_Context TFE_NewContext(IntPtr opts, IntPtr status);
118118

119+
[DllImport(TensorFlowLibName)]
120+
public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx);
121+
122+
[DllImport(TensorFlowLibName)]
123+
public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx);
124+
119125
/// <summary>
120126
///
121127
/// </summary>

0 commit comments

Comments
 (0)