Skip to content

Commit 7764865

Browse files
committed
eager Tape test
1 parent e93f112 commit 7764865

13 files changed

Lines changed: 86 additions & 17 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
[![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US)
1010
[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab)
1111

12-
*master branch is based on tensorflow 2.1 now, v0.15-tensorflow1.15 is from tensorflow1.15.*
12+
*master branch is based on tensorflow 2.2 now, v0.15-tensorflow1.15 is from tensorflow1.15.*
1313

1414
TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp).
1515

@@ -28,7 +28,7 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr
2828

2929
### How to use
3030

31-
| TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.0 |
31+
| TensorFlow | tf 1.13 | tf 1.14 | tf 1.15 | tf 2.2 |
3232
| ----------- | ------- | ------- | ------- | ------ |
3333
| tf.net 0.20 | | | x | x |
3434
| tf.net 0.15 | | x | x | |

docs/assets/tf2.jpg

90 KB
Loading

docs/assets/tf2.psd

386 KB
Binary file not shown.

src/TensorFlowNET.Core/Eager/EagerTensor.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,31 +13,37 @@ public EagerTensor(IntPtr handle) : base(handle)
1313
{
1414
tfe_tensor_handle = handle;
1515
_handle = c_api.TFE_TensorHandleResolve(handle, status);
16+
_id = ops.uid();
1617
}
1718

1819
public EagerTensor(string value, string device_name) : base(value)
1920
{
2021
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
22+
_id = ops.uid();
2123
}
2224

2325
public EagerTensor(int value, string device_name) : base(value)
2426
{
2527
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
28+
_id = ops.uid();
2629
}
2730

2831
public EagerTensor(float[] value, string device_name) : base(value)
2932
{
3033
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
34+
_id = ops.uid();
3135
}
3236

3337
public EagerTensor(double[] value, string device_name) : base(value)
3438
{
3539
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
40+
_id = ops.uid();
3641
}
3742

3843
public EagerTensor(NDArray value, string device_name) : base(value)
3944
{
4045
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
46+
_id = ops.uid();
4147
}
4248

4349
public override string ToString()

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,20 @@ public partial class c_api
102102
public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status);
103103

104104
/// <summary>
105-
///
105+
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
106+
/// is for performance optimization by reusing an exiting unused op rather than
107+
/// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
108+
/// does not set the device name. If it's not `NULL`, then it attempts to parse
109+
/// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
110+
/// than separately calling it because if the existing op has the same
111+
/// `raw_device_name`, it skips parsing and just leave as it is.
106112
/// </summary>
107-
/// <param name="ctx">TFE_Context*</param>
113+
/// <param name="op_to_reset">TFE_Op*</param>
108114
/// <param name="op_or_function_name">const char*</param>
115+
/// <param name="raw_device_name">const char*</param>
109116
/// <param name="status">TF_Status*</param>
110-
/// <param name="op_to_reset">TFE_Op*</param>
111117
[DllImport(TensorFlowLibName)]
112-
public static extern void TFE_OpReset(IntPtr ctx, string op_or_function_name, IntPtr status, IntPtr op_to_reset);
118+
public static extern void TFE_OpReset(IntPtr op_to_reset, string op_or_function_name, string raw_device_name, IntPtr status);
113119

114120
/// <summary>
115121
///
@@ -304,5 +310,17 @@ public partial class c_api
304310
/// <returns>TFE_Executor*</returns>
305311
[DllImport(TensorFlowLibName)]
306312
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);
313+
314+
[DllImport(TensorFlowLibName)]
315+
public static extern void TFE_Test();
316+
317+
[DllImport(TensorFlowLibName)]
318+
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);
319+
320+
[DllImport(TensorFlowLibName)]
321+
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id);
322+
323+
[DllImport(TensorFlowLibName)]
324+
public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status);
307325
}
308326
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System.Collections.Generic;
2+
using System.Linq;
3+
using System;
4+
using static Tensorflow.OpDef.Types;
5+
6+
namespace Tensorflow.Eager
7+
{
8+
/// <summary>
9+
/// python\eager\pywrap_tfe_src.cc
10+
/// </summary>
11+
public partial class wrap_tfe_src
12+
{
13+
14+
}
15+
}

src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ private static TFE_Op GetOp(Context ctx, string op_or_function_name, Status stat
110110
var maybe_op = ReleaseThreadLocalOp();
111111
if (maybe_op != IntPtr.Zero)
112112
{
113-
c_api.TFE_OpReset(ctx, op_or_function_name, status, maybe_op);
113+
c_api.TFE_OpReset(maybe_op, op_or_function_name, ctx.device_name, status);
114114
}
115115
else
116116
{

src/TensorFlowNET.Core/Gradients/GradientActor.cs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ public class GradientActor : IDisposable
2323
bool _watch_accessed_variables;
2424
bool _created_eagerly;
2525
Tape _tape;
26-
int tape_nesting_id_counter = 0;
2726

2827
public GradientActor(bool persistent = false,
2928
bool watch_accessed_variables = true)
@@ -41,18 +40,28 @@ private void _push_tape()
4140
"re-enter an already-active tape.");
4241

4342
if (_tape == null)
44-
{
45-
_tape = new Tape();
46-
_tape.tape = new GradientTape(_persistent, _watch_accessed_variables);
47-
_tape.nesting_id = tape_nesting_id_counter++;
48-
}
43+
_tape = new Tape(_persistent, _watch_accessed_variables);
44+
else
45+
throw new NotImplementedException("");
4946

5047
_recording = true;
5148
}
5249

50+
/// <summary>
51+
/// Marks this tensor to be watched by the given tape.
52+
/// </summary>
53+
/// <param name="x"></param>
5354
public void watch(Tensor x)
5455
{
56+
_tape.watch(x);
57+
}
5558

59+
public Tensor gradient(Tensor target, Tensor sources)
60+
{
61+
c_api.TFE_Test();
62+
//using (var status = new Status())
63+
//c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status);
64+
return null;
5665
}
5766

5867
public void Dispose()

src/TensorFlowNET.Core/Gradients/Tape.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,21 @@
44

55
namespace Tensorflow.Gradients
66
{
7-
public class Tape
7+
public class Tape : DisposableObject
88
{
99
public GradientTape tape { get; set; }
1010
public int nesting_id { get; set; }
1111

12+
public Tape(bool persistent, bool watch_accessed_variables)
13+
{
14+
_handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
15+
}
16+
17+
public void watch(Tensor x)
18+
{
19+
c_api.TFE_TapeWatch(_handle, x, x.Id);
20+
}
21+
1222
public static bool IsDtypeTrainable(DataType dtype)
1323
{
1424
switch (dtype)
@@ -26,5 +36,12 @@ public static bool IsDtypeTrainable(DataType dtype)
2636
return false;
2737
}
2838
}
39+
40+
protected override void DisposeUnmanagedResources(IntPtr handle)
41+
{
42+
}
43+
44+
public static implicit operator IntPtr(Tape tape)
45+
=> tape._handle;
2946
}
3047
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ public partial class Tensor : DisposableObject,
3939
IPackable<Tensor>,
4040
ICanBeFlattened
4141
{
42-
private readonly int _id;
42+
protected int _id;
4343
private readonly Operation _op;
4444
private readonly int _value_index;
4545
private TF_Output? _tf_output;

0 commit comments

Comments
 (0)