Skip to content

Commit 2f3bd61

Browse files
committed
GradientActor
1 parent 7764865 commit 2f3bd61

10 files changed

Lines changed: 126 additions & 25 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.Eager;
1718
using Tensorflow.Operations;
1819

1920
namespace Tensorflow
@@ -259,7 +260,6 @@ public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_
259260
public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null)
260261
=> gen_math_ops.sub(a, b, name: name);
261262

262-
263263
public Tensor divide(Tensor a, Tensor b)
264264
=> a / b;
265265

@@ -348,6 +348,9 @@ public Tensor maximum<T1, T2>(T1 x, T2 y, string name = null)
348348
public Tensor minimum<T1, T2>(T1 x, T2 y, string name = null)
349349
=> gen_math_ops.minimum(x, y, name: name);
350350

351+
public Tensor multiply(Tensor x, Tensor y, string name = null)
352+
=> gen_math_ops.mul(x, y, name: name);
353+
351354
/// <summary>
352355
/// return x * y
353356
/// </summary>

src/TensorFlowNET.Core/Eager/EagerTensor.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,48 +2,52 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
5+
using static Tensorflow.Binding;
56

67
namespace Tensorflow.Eager
78
{
89
public partial class EagerTensor : Tensor
910
{
1011
Status status = new Status();
1112
TFE_TensorHandle tfe_tensor_handle;
13+
public IntPtr EagerTensorHandle { get; set; }
14+
1215
public EagerTensor(IntPtr handle) : base(handle)
1316
{
1417
tfe_tensor_handle = handle;
1518
_handle = c_api.TFE_TensorHandleResolve(handle, status);
16-
_id = ops.uid();
1719
}
1820

1921
public EagerTensor(string value, string device_name) : base(value)
2022
{
2123
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
22-
_id = ops.uid();
2324
}
2425

2526
public EagerTensor(int value, string device_name) : base(value)
2627
{
2728
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
28-
_id = ops.uid();
29+
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle);
30+
}
31+
32+
public EagerTensor(float value, string device_name) : base(value)
33+
{
34+
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
35+
EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle);
2936
}
3037

3138
public EagerTensor(float[] value, string device_name) : base(value)
3239
{
3340
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
34-
_id = ops.uid();
3541
}
3642

3743
public EagerTensor(double[] value, string device_name) : base(value)
3844
{
3945
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
40-
_id = ops.uid();
4146
}
4247

4348
public EagerTensor(NDArray value, string device_name) : base(value)
4449
{
4550
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
46-
_id = ops.uid();
4751
}
4852

4953
public override string ToString()

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ namespace Tensorflow
77
{
88
public partial class c_api
99
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer);
12+
13+
[UnmanagedFunctionPointer(CallingConvention.StdCall)]
14+
public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs);
15+
1016
/// <summary>
1117
/// Return a new options object.
1218
/// </summary>
@@ -186,6 +192,9 @@ public partial class c_api
186192
[DllImport(TensorFlowLibName)]
187193
public static extern TFE_TensorHandle TFE_NewTensorHandle(IntPtr t, IntPtr status);
188194

195+
[DllImport(TensorFlowLibName)]
196+
public static extern TFE_TensorHandle TFE_EagerTensorFromHandle(IntPtr ctx, IntPtr h);
197+
189198
/// <summary>
190199
/// Sets the default execution mode (sync/async). Note that this can be
191200
/// overridden per thread using TFE_ContextSetExecutorForThread.
@@ -312,15 +321,21 @@ public partial class c_api
312321
public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);
313322

314323
[DllImport(TensorFlowLibName)]
315-
public static extern void TFE_Test();
324+
public static extern IntPtr TFE_FastPathExecute(IntPtr ctx,
325+
string device_name,
326+
string op_name,
327+
string name,
328+
IntPtr[] args,
329+
int input_size,
330+
IntPtr status);
316331

317332
[DllImport(TensorFlowLibName)]
318333
public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables);
319334

320335
[DllImport(TensorFlowLibName)]
321-
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor, int tensor_id);
336+
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor);
322337

323338
[DllImport(TensorFlowLibName)]
324-
public static extern void TFE_TapeGradient(IntPtr tape, long[] targetTensorIds, IntPtr[] target, long[] sourcesTensorIds, IntPtr status);
339+
public static extern void TFE_TapeGradient(IntPtr tape, IntPtr[] target, IntPtr sources, IntPtr status);
325340
}
326341
}

src/TensorFlowNET.Core/Gradients/GradientActor.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Eager;
45
using static Tensorflow.Binding;
56

67
namespace Tensorflow.Gradients
@@ -53,14 +54,16 @@ private void _push_tape()
5354
/// <param name="x"></param>
5455
public void watch(Tensor x)
5556
{
56-
_tape.watch(x);
57+
_tape.watch(x as EagerTensor);
5758
}
5859

5960
public Tensor gradient(Tensor target, Tensor sources)
6061
{
61-
c_api.TFE_Test();
62-
//using (var status = new Status())
63-
//c_api.TFE_TapeGradient(_tape, new long[] { target.Id }, status);
62+
using (var status = new Status())
63+
{
64+
c_api.TFE_TapeGradient(_tape, new IntPtr[] { target }, IntPtr.Zero, status);
65+
}
66+
6467
return null;
6568
}
6669

src/TensorFlowNET.Core/Gradients/Tape.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Eager;
45

56
namespace Tensorflow.Gradients
67
{
@@ -14,9 +15,9 @@ public Tape(bool persistent, bool watch_accessed_variables)
1415
_handle = c_api.TFE_TapeSetNew(persistent, watch_accessed_variables);
1516
}
1617

17-
public void watch(Tensor x)
18+
public void watch(EagerTensor x)
1819
{
19-
c_api.TFE_TapeWatch(_handle, x, x.Id);
20+
c_api.TFE_TapeWatch(_handle, x.EagerTensorHandle);
2021
}
2122

2223
public static bool IsDtypeTrainable(DataType dtype)

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,28 @@ public static Tensor asin(Tensor x, string name = null)
192192
return _op.outputs[0];
193193
}
194194

195+
public static Tensor add(Tensor x, Tensor y, string name = null)
196+
{
197+
if (tf.context.executing_eagerly())
198+
{
199+
using (var status = new Status())
200+
{
201+
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
202+
"Add", name, new IntPtr[]
203+
{
204+
(x as EagerTensor).EagerTensorHandle,
205+
(y as EagerTensor).EagerTensorHandle
206+
}, 2, status);
207+
status.Check(true);
208+
return new EagerTensor(_result);
209+
}
210+
}
211+
212+
var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });
213+
214+
return _op.output;
215+
}
216+
195217
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
196218
{
197219
if (tf.context.executing_eagerly())
@@ -593,6 +615,28 @@ public static Tensor sqrt(Tensor x, string name = null)
593615
return _op.outputs[0];
594616
}
595617

618+
public static Tensor sub(Tensor x, Tensor y, string name = null)
619+
{
620+
if (tf.context.executing_eagerly())
621+
{
622+
using (var status = new Status())
623+
{
624+
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
625+
"Sub", name, new IntPtr[]
626+
{
627+
(x as EagerTensor).EagerTensorHandle,
628+
(y as EagerTensor).EagerTensorHandle
629+
}, 2, status);
630+
status.Check(true);
631+
return new EagerTensor(_result);
632+
}
633+
}
634+
635+
var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y });
636+
637+
return _op.output;
638+
}
639+
596640
public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
597641
{
598642
if (tf.context.executing_eagerly())
@@ -667,6 +711,28 @@ public static Tensor atan2(Tensor y, Tensor x, string name = null)
667711
return _op.output;
668712
}
669713

714+
public static Tensor mul(Tensor x, Tensor y, string name = null)
715+
{
716+
if (tf.context.executing_eagerly())
717+
{
718+
using (var status = new Status())
719+
{
720+
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
721+
"Mul", name, new IntPtr[]
722+
{
723+
(x as EagerTensor).EagerTensorHandle,
724+
(y as EagerTensor).EagerTensorHandle
725+
}, 2, status);
726+
status.Check(true);
727+
return new EagerTensor(_result);
728+
}
729+
}
730+
731+
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });
732+
733+
return _op.output;
734+
}
735+
670736
public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
671737
{
672738
if (tf.context.executing_eagerly())
@@ -693,8 +759,17 @@ public static Tensor real_div(Tensor x, Tensor y, string name = null)
693759
{
694760
if (tf.context.executing_eagerly())
695761
{
696-
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "RealDiv", name, null, x, y);
697-
return _result;
762+
using (var status = new Status())
763+
{
764+
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
765+
"RealDiv", name, new IntPtr[]
766+
{
767+
(x as EagerTensor).EagerTensorHandle,
768+
(y as EagerTensor).EagerTensorHandle
769+
}, 2, status);
770+
status.Check(true);
771+
return new EagerTensor(_result);
772+
}
698773
}
699774

700775
var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y });

src/TensorFlowNET.Core/TensorFlow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
<TargetFramework>netstandard2.0</TargetFramework>
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
7-
<TargetTensorFlow>2.01.0</TargetTensorFlow>
7+
<TargetTensorFlow>2.2.0</TargetTensorFlow>
88
<Version>0.20.0</Version>
99
<LangVersion>8.0</LangVersion>
1010
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>

src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
</ItemGroup>
1919

2020
<ItemGroup>
21-
<PackageReference Include="BenchmarkDotNet" Version="0.12.0" />
21+
<PackageReference Include="BenchmarkDotNet" Version="0.12.1" />
2222
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" />
2323
<PackageReference Include="TensorFlow.NET" Version="0.15.1" />
2424
</ItemGroup>

test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
<ItemGroup>
3232
<PackageReference Include="FluentAssertions" Version="5.10.3" />
3333
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" />
34-
<PackageReference Include="MSTest.TestAdapter" Version="2.1.0" />
35-
<PackageReference Include="MSTest.TestFramework" Version="2.1.0" />
34+
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" />
35+
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" />
3636
<PackageReference Include="NumSharp.Lite" Version="0.1.7" />
3737
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.1.0" />
3838
</ItemGroup>

test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88

99
<ItemGroup>
1010
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.5.0" />
11-
<PackageReference Include="MSTest.TestAdapter" Version="2.1.0" />
12-
<PackageReference Include="MSTest.TestFramework" Version="2.1.0" />
13-
<PackageReference Include="coverlet.collector" Version="1.2.0">
11+
<PackageReference Include="MSTest.TestAdapter" Version="2.1.1" />
12+
<PackageReference Include="MSTest.TestFramework" Version="2.1.1" />
13+
<PackageReference Include="coverlet.collector" Version="1.2.1">
1414
<PrivateAssets>all</PrivateAssets>
1515
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
1616
</PackageReference>

0 commit comments

Comments
 (0)