Skip to content

Commit 9420ba3

Browse files
committed
Fix the error of loaded function model backward.
1 parent 1d1657d commit 9420ba3

45 files changed

Lines changed: 870 additions & 409 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Text;
44
using Google.Protobuf;
5+
using Protobuf.Text;
56
using static Tensorflow.Binding;
67

78
namespace Tensorflow.Contexts

src/TensorFlowNET.Core/Eager/EagerRunner.MustRecordGradient.cs

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,36 @@ public bool MustRecordGradient()
1212
return HasGradientTape();
1313
}
1414

15-
private bool ShouldRecord(Tensor[] inputs)
15+
public int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors)
1616
{
17-
bool should_record = false;
18-
foreach (var tape in tf.GetTapeSet())
17+
var tape_set = tf.GetTapeSet();
18+
var input_ids = MakeTensorIDList(tensors);
19+
var input_dtypes = MakeTensorDtypeList(tensors);
20+
bool some_tape_watching = false;
21+
if (tape_set is not null && tape_set.Count > 0)
1922
{
20-
if (tape.ShouldRecord(inputs))
23+
foreach (var tape in tape_set)
2124
{
22-
should_record = true;
23-
break;
25+
if (tape.ShouldRecord(input_ids, input_dtypes))
26+
{
27+
if (tape.Persistent || some_tape_watching)
28+
{
29+
return gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER;
30+
}
31+
some_tape_watching = true;
32+
}
2433
}
2534
}
26-
return should_record;
35+
// skip the forward_accumulators.
36+
37+
if (some_tape_watching)
38+
{
39+
return gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER;
40+
}
41+
else
42+
{
43+
return gradients_util.POSSIBLE_GRADIENT_TYPES_NONE;
44+
}
2745
}
2846
}
2947
}

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,17 @@ public bool RecordGradient(string op_name,
1313
Tensor[] results,
1414
BackwardFunction backwardFunction = null)
1515
{
16-
bool should_record = ShouldRecord(inputs);
16+
var input_ids = MakeTensorIDList(inputs);
17+
var input_dtypes = MakeTensorDtypeList(inputs);
18+
bool should_record = false;
19+
foreach (var tape in tf.GetTapeSet())
20+
{
21+
if (tape.ShouldRecord(input_ids, input_dtypes))
22+
{
23+
should_record = true;
24+
break;
25+
}
26+
}
1727

1828
if (!should_record)
1929
{
@@ -59,7 +69,7 @@ public bool RecordGradient(string op_name,
5969
op_inputs = inputs;*/
6070

6171
backwardFunction = backwardFunction ?? GetGradientFunction(op_name, inputs, attrs, results);
62-
TapeSetRecordOperation(op_name, inputs, results, backwardFunction);
72+
TapeSetRecordOperation(op_name, inputs, results, input_ids, input_dtypes, backwardFunction);
6373

6474
return true;
6575
}
@@ -129,10 +139,5 @@ bool CouldBackprop()
129139
{
130140
return HasGradientTape();
131141
}
132-
133-
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
134-
{
135-
return tensors.Select(x => x.dtype).ToArray();
136-
}
137142
}
138143
}
Lines changed: 162 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
using System;
1+
using OneOf.Types;
2+
using System;
23
using Tensorflow.Gradients;
34
using Tensorflow.Util;
5+
using static Tensorflow.Binding;
46

57
namespace Tensorflow.Eager
68
{
@@ -9,40 +11,183 @@ namespace Tensorflow.Eager
911
/// </summary>
1012
public partial class EagerRunner
1113
{
14+
/// <summary>
15+
///
16+
/// </summary>
17+
/// <param name="tape"></param>
18+
/// <param name="target"></param>
19+
/// <param name="sources"></param>
20+
/// <param name="output_gradients"></param>
21+
/// <param name="unconnected_gradients">determines the value returned if the target and
22+
/// sources are unconnected. When 'none' the value returned is None, whereas when
23+
/// 'zero' a zero tensor in the same shape as the sources is returned.</param>
24+
/// <returns></returns>
25+
/// <exception cref="RuntimeError"></exception>
1226
public Tensor[] TFE_TapeGradient(ITape tape,
1327
Tensor[] target,
1428
Tensor[] sources,
15-
Tensor[] output_gradients)
29+
List<Tensor> output_gradients,
30+
Tensor[] sources_raw,
31+
string unconnected_gradients)
1632
{
17-
var target_vec = target;
18-
var sources_vec = sources;
19-
var sources_set = sources_vec;
33+
if (!tape.Persistent)
34+
{
35+
var tape_set = tf.GetTapeSet();
36+
if (tape_set.Contains(tape))
37+
{
38+
throw new RuntimeError("gradient() cannot be invoked within the " +
39+
"GradientTape context (i.e., while operations are being " +
40+
"recorded). Either move the call to gradient() to be " +
41+
"outside the 'with tf.GradientTape' block, or " +
42+
"use a persistent tape: " +
43+
"'with tf.GradientTape(persistent=true)'");
44+
}
45+
}
46+
47+
var target_vec = MakeTensorIDList(target);
48+
var sources_vec = MakeTensorIDList(sources);
49+
HashSet<long> sources_set = new HashSet<long>(sources_vec);
50+
var source_tensors_that_are_targets = new UnorderedMap<long, TapeTensor>();
51+
52+
int len = target.Length;
53+
for(int i = 0; i < len; i++)
54+
{
55+
var target_id = target_vec[i];
56+
if (sources_set.Contains(target_id))
57+
{
58+
var tensor = target[i];
59+
source_tensors_that_are_targets[target_id] = TapeTensorFromTensor(tensor);
60+
}
61+
}
62+
63+
List<Tensor> outgrad_vec = new();
64+
if(output_gradients is not null)
65+
{
66+
outgrad_vec = output_gradients.ToList();
67+
}
68+
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
2069

21-
var seq_array = target;
22-
var source_tensors_that_are_targets = new UnorderedMap<Tensor, TapeTensor>();
2370

24-
for (int i = 0; i < target.Length; ++i)
71+
bool unconnected_gradients_zero = unconnected_gradients == "zero";
72+
Tensor[] sources_obj = null;
73+
if (unconnected_gradients_zero)
2574
{
26-
source_tensors_that_are_targets.Add(target_vec[i], new TapeTensor(seq_array[i]));
75+
sources_obj = MakeTensorList(sources_raw);
2776
}
2877

29-
if (output_gradients != null)
78+
if (result.Length > 0)
3079
{
31-
throw new NotImplementedException("");
80+
for(int i = 0; i < result.Length; i++)
81+
{
82+
if (result[i] is null && unconnected_gradients_zero)
83+
{
84+
var dtype = sources_obj[i].dtype;
85+
result[i] = new TapeTensor(sources_vec[i], dtype, sources_obj[i]).ZerosLike();
86+
}
87+
}
3288
}
33-
else
89+
return result;
90+
}
91+
92+
Tensor[] MakeTensorList(IEnumerable<Tensor> tensors)
93+
{
94+
return tensors.ToArray();
95+
}
96+
97+
long[] MakeTensorIDList(Tensor[] tensors)
98+
{
99+
int len = tensors.Length;
100+
long[] ids = new long[len];
101+
for(int i = 0; i < len; i++)
102+
{
103+
var tensor = tensors[i];
104+
ids[i] = tensor.Id;
105+
}
106+
return ids;
107+
}
108+
109+
TF_DataType[] MakeTensorDtypeList(Tensor[] tensors)
110+
{
111+
int len = tensors.Length;
112+
TF_DataType[] dtypes = new TF_DataType[len];
113+
for (int i = 0; i < len; i++)
34114
{
35-
output_gradients = new Tensor[0];
115+
var tensor = tensors[i];
116+
dtypes[i] = tensor.dtype;
36117
}
118+
return dtypes;
119+
}
37120

38-
var outgrad_vec = MakeTensorList(output_gradients);
121+
TapeTensor TapeTensorFromTensor(Tensor tensor)
122+
{
123+
long id = tensor.Id;
124+
var dtype = tensor.dtype;
125+
if (tensor is EagerTensor)
126+
{
127+
var handle = tensor.EagerTensorHandle;
128+
if (DTypeNeedsHandleData(dtype))
129+
{
130+
return new TapeTensor(id, c_api.TFE_TensorHandleDataType(handle), tensor);
131+
}
132+
133+
Status status = new();
134+
int num_dims = c_api.TFE_TensorHandleNumDims(handle, status);
135+
long[] dims = new long[num_dims];
136+
for(int i = 0; i < num_dims; i++)
137+
{
138+
dims[i] = c_api.TFE_TensorHandleDim(handle, i, status);
139+
}
140+
Shape tensor_shape = new(dims);
141+
142+
if(status.Code != TF_Code.TF_OK)
143+
{
144+
return new TapeTensor(id, TF_DataType.DtInvalid, Shape.Null);
145+
}
146+
else
147+
{
148+
return new TapeTensor(id, dtype, tensor_shape);
149+
}
150+
}
151+
var shape_tuple = tensor.shape.dims;
152+
if(ListContainNone(shape_tuple) || DTypeNeedsHandleData(dtype))
153+
{
154+
return new TapeTensor(id, dtype, tensor);
155+
}
156+
long[] l = new long[shape_tuple.Length];
157+
for(int i = 0; i < shape_tuple.Length; i++)
158+
{
159+
if (shape_tuple[i] < 0)
160+
{
161+
l[i] = 0;
162+
}
163+
else
164+
{
165+
l[i] = shape_tuple[i];
166+
}
167+
}
168+
return new TapeTensor(id, dtype, new Shape(l));
169+
}
39170

40-
return tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec);
171+
bool DTypeNeedsHandleData(TF_DataType dtype)
172+
{
173+
return dtype == dtypes.variant || dtype == dtypes.resource;
41174
}
42175

43-
Tensor[] MakeTensorList(Tensor[] tensors)
176+
bool ListContainNone(long[] list)
44177
{
45-
return tensors;
178+
int len = list.Length;
179+
if(len == 0)
180+
{
181+
return true;
182+
}
183+
for(int i = 0; i < len; i++)
184+
{
185+
if (list[i] == -1)
186+
{
187+
return true;
188+
}
189+
}
190+
return false;
46191
}
47192
}
48193
}

src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordBackprop.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ namespace Tensorflow.Eager
77
public partial class EagerRunner
88
{
99
void TapeSetRecordBackprop(string op_type,
10-
Tensor[] input_tensors,
11-
TapeTensor[] output_tensors,
10+
TapeTensor[] output_info,
11+
long[] input_ids,
12+
TF_DataType[] input_detyps,
1213
BackwardFunction backward_function)
1314
{
1415
if (!CouldBackprop())
@@ -18,7 +19,7 @@ void TapeSetRecordBackprop(string op_type,
1819

1920
foreach (var tape in tf.GetTapeSet())
2021
{
21-
tape.RecordOperation(op_type, input_tensors, output_tensors, backward_function);
22+
tape.RecordOperation(op_type, output_info, input_ids, input_detyps, backward_function);
2223
}
2324
}
2425
}

src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetRecordOperation.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,28 @@ public partial class EagerRunner
1010
public bool TapeSetRecordOperation(string op_type,
1111
Tensor[] input_tensors,
1212
Tensor[] output_tensors,
13+
long[] input_ids,
14+
TF_DataType[] input_dtypes,
1315
BackwardFunction backward_function)
1416
{
15-
var output_info = output_tensors.Select(x => new TapeTensor(x)).ToArray();
16-
17+
var output_info = output_tensors.Select(t => TapeTensorFromTensor(t)).ToArray();
1718
if (!TapeSetRecordForwardprop(op_type, input_tensors, output_info,
1819
backward_function))
1920
return false;
2021

21-
TapeSetRecordBackprop(op_type, input_tensors, output_info,
22+
TapeSetRecordBackprop(op_type, output_info, input_ids, input_dtypes,
2223
backward_function);
2324

2425
return true;
2526
}
27+
28+
public void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
29+
Tensor[] input_tensors, BackwardFunction backward_function)
30+
{
31+
var input_ids = MakeTensorIDList(input_tensors);
32+
var input_dtypes = MakeTensorDtypeList(input_tensors);
33+
TapeSetRecordOperation(op_type, input_tensors, output_tensors, input_ids, input_dtypes,
34+
backward_function);
35+
}
2636
}
2737
}

src/TensorFlowNET.Core/Eager/IEagerRunner.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,14 @@ Tensor[] TFE_Execute(Context ctx,
2929
Tensor[] TFE_TapeGradient(ITape tape,
3030
Tensor[] target,
3131
Tensor[] sources,
32-
Tensor[] output_gradients);
32+
List<Tensor> output_gradients,
33+
Tensor[] sources_raw,
34+
string unconnected_gradients);
35+
36+
void TFE_TapeSetRecordOperation(string op_type, Tensor[] output_tensors,
37+
Tensor[] input_tensors, BackwardFunction backward_function);
38+
39+
int TFE_TapeSetPossibleGradientTypes(Tensor[] tensors);
3340

3441
bool RecordGradient(string op_name,
3542
Tensor[] inputs,

0 commit comments

Comments
 (0)