Skip to content

Commit ca6f8b2

Browse files
committed
TFE_TapeVariableAccessed
1 parent 3f8b658 commit ca6f8b2

15 files changed

Lines changed: 252 additions & 64 deletions

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -390,7 +390,7 @@ public Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct
390390
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
391391

392392
/// <summary>
/// Power operation on the <c>tf</c> math API. Forwards <paramref name="x"/>,
/// <paramref name="y"/> and <paramref name="name"/> unchanged to the underlying ops helper:
/// the removed line (393-) called <c>gen_math_ops.pow</c>, the added line (393+) calls
/// <c>math_ops.pow</c>.
/// </summary>
/// <param name="x">First operand, passed through as-is.</param>
/// <param name="y">Second operand, passed through as-is.</param>
/// <param name="name">Operation name; defaults to <c>"pow"</c>.</param>
/// <returns>The tensor returned by the underlying <c>pow</c> helper.</returns>
public Tensor pow<T1, T2>(T1 x, T2 y, string name = "pow")
393-
=> gen_math_ops.pow(x, y, name: name);
393+
=> math_ops.pow(x, y, name: name);
394394

395395
/// <summary>
396396
/// Divides `x / y` elementwise, rounding toward the most negative integer.

src/TensorFlowNET.Core/Eager/EagerTensor.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ public override string ToString()
5353

5454
public static string GetFormattedString(TF_DataType dtype, NDArray nd)
5555
{
56+
if (nd.size == 0)
57+
return "[]";
58+
5659
switch (dtype)
5760
{
5861
case TF_DataType.TF_STRING:

src/TensorFlowNET.Core/Eager/c_api.eager.cs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -375,6 +375,9 @@ public static extern IntPtr TFE_QuickExecute(IntPtr ctx,
375375
/// <summary>
/// P/Invoke declaration for the <c>TFE_TapeWatch</c> export of the native library named by
/// <c>TensorFlowLibName</c>. Returns nothing.
/// </summary>
/// <param name="tape">Raw pointer passed as the native call's tape argument.</param>
/// <param name="tensor">Raw pointer passed as the native call's tensor argument.</param>
[DllImport(TensorFlowLibName)]
376376
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor);
377377

378+
/// <summary>
/// P/Invoke declaration for the <c>TFE_TapeVariableAccessed</c> export of the native library
/// named by <c>TensorFlowLibName</c>. Returns nothing.
/// </summary>
/// <param name="variable">Raw pointer passed as the native call's single argument.</param>
[DllImport(TensorFlowLibName)]
379+
public static extern void TFE_TapeVariableAccessed(IntPtr variable);
380+
378381
[DllImport(TensorFlowLibName)]
379382
public static extern IntPtr TFE_TapeGradient(IntPtr tape,
380383
IntPtr[] target, int target_size,

src/TensorFlowNET.Core/Gradients/GradientActor.cs

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Eager;
56
using static Tensorflow.Binding;
@@ -65,7 +66,7 @@ public void watch(Tensor x)
6566
_tape.watch(x as EagerTensor);
6667
}
6768

68-
public Tensor gradient(Tensor target, Tensor sources)
69+
public Tensor gradient(Tensor target, Tensor source)
6970
{
7071
if(_recording)
7172
{
@@ -76,15 +77,33 @@ public Tensor gradient(Tensor target, Tensor sources)
7677
using var status = new Status();
7778
var et = c_api.TFE_TapeGradient(_tape,
7879
new [] { (target as EagerTensor).EagerTensorHandle }, 1,
79-
new [] { (sources as EagerTensor).EagerTensorHandle }, 1,
80+
new [] { (source as EagerTensor).EagerTensorHandle }, 1,
8081
status);
8182
status.Check(true);
8283
return new EagerTensor(et);
8384
}
8485

86+
/// <summary>
/// Computes the gradient of <paramref name="target"/> with respect to the handles of
/// <paramref name="sources"/> by calling the native <c>TFE_TapeGradient</c> on this actor's tape.
/// </summary>
/// <param name="target">Tensor to differentiate; must be an <c>EagerTensor</c>.</param>
/// <param name="sources">
/// Variables to differentiate against; every element must be non-null and expose an
/// <c>EagerTensor</c> as its <c>handle</c>.
/// </param>
/// <returns>The value returned by <c>TFE_TapeGradient</c>, converted to <c>Tensor</c>.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="target"/> or <paramref name="sources"/> is null.
/// </exception>
/// <exception cref="ArgumentException">
/// <paramref name="target"/> is not an <c>EagerTensor</c>, or a source is null or has a
/// non-eager handle.
/// </exception>
public Tensor gradient(Tensor target, ResourceVariable[] sources)
{
    // `is null` sidesteps any user-defined equality operators on the tensor types.
    if (target is null)
        throw new ArgumentNullException(nameof(target));
    if (sources is null)
        throw new ArgumentNullException(nameof(sources));

    // Previously an `as` cast was dereferenced unchecked, so a non-eager input surfaced as a
    // bare NullReferenceException. Validate before touching the tape so a bad call neither
    // pops it nor reaches native code.
    if (!(target is EagerTensor eagerTarget))
        throw new ArgumentException("target must be an EagerTensor.", nameof(target));

    var sourceHandles = sources
        .Select(v => v?.handle is EagerTensor eagerHandle
            ? eagerHandle.EagerTensorHandle
            : throw new ArgumentException(
                "Every source variable must be non-null and have an EagerTensor handle.",
                nameof(sources)))
        .ToArray();

    // While still recording, a non-persistent tape is popped before the gradient is taken.
    if (_recording && !_persistent)
        _pop_tape();

    using var status = new Status();
    EagerTensorHandle et = c_api.TFE_TapeGradient(_tape,
        new[] { eagerTarget.EagerTensorHandle }, 1,
        sourceHandles, sourceHandles.Length,
        status);
    status.Check(true);
    return et;
}
102+
85103
/// <summary>
/// Disposes the actor: if it is still flagged as recording, its tape is popped.
/// </summary>
public void Dispose()
{
    // Nothing to unwind unless recording is still flagged.
    if (!_recording)
        return;

    _pop_tape();
}
89108
}
90109
}

src/TensorFlowNET.Core/Gradients/Tape.cs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,11 @@ public void pop_tape(Tape tape)
2525
c_api.TFE_TapeSetRemove(tape);
2626
}
2727

28+
/// <summary>
/// Passes a variable's handle to the native <c>TFE_TapeVariableAccessed</c> binding.
/// </summary>
/// <param name="variable">Variable whose <c>handle</c> is forwarded to the native call.</param>
public static void variable_accessed(ResourceVariable variable)
{
    // NOTE(review): `as` yields null when the handle is not an EagerTensor, and nothing here
    // guards that case before the value reaches the native call. Confirm callers only get
    // here with eager handles.
    var eagerHandle = variable.handle as EagerTensor;
    c_api.TFE_TapeVariableAccessed(eagerHandle);
}
32+
2833
public static bool IsDtypeTrainable(DataType dtype)
2934
{
3035
switch (dtype)

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 59 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -220,6 +220,18 @@ public static Tensor prevent_gradient(Tensor input, string message = "", string
220220
/// <param name="name"></param>
221221
public static Tensor identity(Tensor input, string name = null)
222222
{
223+
// Eager mode: run the "Identity" op directly through the fast-path native call and
// return its result without building a graph op.
if (tf.context.executing_eagerly())
224+
{
225+
using var status = new Status();
226+
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
227+
"Identity", name, new IntPtr[]
228+
{
229+
// NOTE(review): unchecked `as` cast; a non-eager input becomes null here. Confirm
// that cannot happen while executing eagerly.
input as EagerTensor
230+
}, 1, null, status);
231+
// Check the native status before returning the result.
status.Check(true);
232+
return tensor;
233+
}
234+
223235
// Otherwise: create the "Identity" op through the op-def library and return its output.
var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });
224236

225237
return _op.output;
@@ -258,14 +270,14 @@ public static Tensor fill<T>(Tensor dims, T value, string name = null)
258270
if (tf.context.executing_eagerly())
259271
{
260272
using var status = new Status();
261-
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
273+
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
262274
"Fill", name, new IntPtr[]
263275
{
264276
dims as EagerTensor,
265277
value as EagerTensor
266278
}, 2, null, status);
267279
status.Check(true);
268-
return new EagerTensor(tensor);
280+
return tensor;
269281
}
270282

271283
var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value });
@@ -281,6 +293,18 @@ value as EagerTensor
281293
/// <returns>A tuple of `Tensor` objects (r0, r1).</returns>
282294
public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
283295
{
296+
// Eager mode: run "BroadcastGradientArgs" through the fast-path native call.
// NOTE(review): `_result` is assigned but never returned or otherwise used. After
// status.Check the method falls out of this block and also runs the op-def path below,
// whose outputs are what the caller receives. This looks unfinished; confirm how the two
// outputs should be read from the eager result and returned here.
if (tf.context.executing_eagerly())
297+
{
298+
using var status = new Status();
299+
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
300+
"BroadcastGradientArgs", name, new IntPtr[]
301+
{
302+
// NOTE(review): unchecked `as` casts; a non-eager input becomes null here.
s0 as EagerTensor,
303+
s1 as EagerTensor
304+
}, 2, null, status);
305+
status.Check(true);
306+
}
307+
284308
// Create the "BroadcastGradientArgs" op through the op-def library and return its first
// two outputs as the (r0, r1) tuple.
var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });
285309

286310
return (_op.outputs[0], _op.outputs[1]);
@@ -371,10 +395,19 @@ public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_I
371395
{
372396
if (tf.context.executing_eagerly())
373397
{
374-
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
375-
"Shape", name, null,
376-
input, "out_type", out_type);
377-
return _result;
398+
using var status = new Status();
399+
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
400+
"Shape", name, new IntPtr[]
401+
{
402+
input as EagerTensor,
403+
}, 1,
404+
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
405+
{
406+
"out_type", out_type
407+
}, status),
408+
status);
409+
status.Check(true);
410+
return tensor;
378411
}
379412

380413
var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type });
@@ -455,12 +488,26 @@ public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tenso
455488
{
456489
if (tf.context.executing_eagerly())
457490
{
458-
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
459-
"StridedSlice", name, null,
460-
input, begin, end, strides, "begin_mask", begin_mask,
461-
"end_mask", end_mask, "ellipsis_mask", ellipsis_mask,
462-
"new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask);
463-
return _result;
491+
using var status = new Status();
492+
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
493+
"StridedSlice", name, new IntPtr[]
494+
{
495+
input as EagerTensor,
496+
begin as EagerTensor,
497+
end as EagerTensor,
498+
strides as EagerTensor,
499+
}, 4,
500+
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
501+
{
502+
"begin_mask", begin_mask,
503+
"end_mask", end_mask,
504+
"ellipsis_mask", ellipsis_mask,
505+
"new_axis_mask", new_axis_mask,
506+
"shrink_axis_mask", shrink_axis_mask
507+
}, status),
508+
status);
509+
status.Check(true);
510+
return tensor;
464511
}
465512

466513
var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new

0 commit comments

Comments
 (0)