forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharray_grad.cs
More file actions
379 lines (337 loc) · 15.2 KB
/
array_grad.cs
File metadata and controls
379 lines (337 loc) · 15.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
/// <summary>
/// tensorflow\python\ops\array_grad.py
/// </summary>
[RegisterGradient("array_grad")]
public class array_grad
{
/// <summary>
/// Gradient for the BroadcastTo op.
/// </summary>
/// <param name="op">The BroadcastTo operation; inputs[0] is the value, inputs[1] the broadcast shape.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>
/// A two-element array: the gradient for the input value, and null for the shape input.
/// </returns>
[RegisterGradient("BroadcastTo")]
public static Tensor[] _BroadcastToGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var source = op.inputs[0];
    var target_shape = op.inputs[1];
    var source_shape = array_ops.shape(source);

    // Only the second result of broadcast_gradient_args is needed here.
    var (_, axes_to_sum) = gen_array_ops.broadcast_gradient_args(target_shape, source_shape);

    // Sum over those axes (keeping dims), then reshape back to the input's shape.
    var summed = math_ops.reduce_sum(upstream, axis: axes_to_sum, keepdims: true);
    var source_grad = array_ops.reshape(summed, source_shape);

    return new Tensor[] { source_grad, null };
}
/// <summary>
/// Gradient for the ConcatV2 op; delegates to <c>_ConcatGradHelper</c> using
/// index 0 for the first value and -1 for both the end-value and axis indices.
/// </summary>
/// <param name="op">The ConcatV2 operation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>The array produced by <c>_ConcatGradHelper</c>.</returns>
[RegisterGradient("ConcatV2")]
public static Tensor[] _ConcatV2Grad(Operation op, Tensor[] grads)
    => _ConcatGradHelper(op, grads[0], start_value_index: 0, end_value_index: -1, dim_index: -1);
/// <summary>
/// Gradient for concat op.
/// </summary>
/// <param name="op">An operation.</param>
/// <param name="grad">
/// `Tensor` or `IndexedSlices` representing the gradients with respect
/// to each output of the op.
/// </param>
/// <param name="start_value_index">An integer index of the first value in the op.inputs.</param>
/// <param name="end_value_index">
/// An integer index one past the last value in the op.inputs. A negative value
/// counts from the end of op.inputs (so -1 excludes only the last input).
/// </param>
/// <param name="dim_index">An integer index of concat_dim or axis parameter in op.inputs.</param>
/// <returns>
/// Tensors representing the partial gradients with respect to each input
/// of the op. The slot belonging to the axis input is null.
/// </returns>
private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_value_index, int end_value_index, int dim_index)
{
    // Degenerate concatenation, just return grad.
    if (len(op.inputs) == 2)
        return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };

    var concat_dim = op.inputs[dim_index];

    // Resolve a negative end index relative to the number of inputs so the
    // value range is [start_value_index, end) regardless of where it starts.
    var end_exclusive = end_value_index < 0 ? op.inputs.Length + end_value_index : end_value_index;
    var input_values = op.inputs._inputs.Skip(start_value_index)
        .Take(end_exclusive - start_value_index)
        .ToArray();

    var out_grads = new List<Tensor>();
    if (concat_dim is EagerTensor)
    {
        // The axis value is available immediately: normalize it to a
        // non-negative axis and split grad by each input's size on that axis.
        var dim_int = (int)concat_dim;
        var non_neg_concat_dim = dim_int < 0
            ? input_values[0].rank + dim_int
            : dim_int % input_values[0].rank;
        var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
        var sizes_tensor = constant_op.constant(sizes);
        out_grads = array_ops.split(grad, sizes_tensor, non_neg_concat_dim).ToList();
    }
    else
    {
        if (constant_op.is_constant(concat_dim))
        {
            /*If concat_dim is a constant defined in a different context,
            then we duplicate it in the current context to avoid passing it
            through an Enter node.
            This is a small optimization in general, but it is required when
            compiling with XLA, as XLA needs the concat input to be folded into a
            constant.*/
            var grad_context = control_flow_util.GetOutputContext(grad.op);
            var dim_context = control_flow_util.GetOutputContext(concat_dim.op);
            if (dim_context != grad_context)
            {
                var value = tensor_util.constant_value(concat_dim);
                concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
            }
        }

        // The gradient itself must be built whether or not concat_dim is a
        // constant; only the duplication above depends on that.

        // Using mod here for convenience since concat_dim is already verified
        // in concat implementation to be within the allowed [-rank, rank) range.
        var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);

        // Get the inputs' tensor shapes
        var sizes = _ExtractInputShapes(input_values);

        /* The magic number of 16 was found through benchmarking a range of sizes
        on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
        cases when switching implementations at N=16, but it is possible that
        there will be a small number of performance regressions.*/
        if (len(sizes) > 16)
        {
            // extract the size of each input along the concat dimension
            var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
                new Tensor[] { non_neg_concat_dim, tf.constant(0) },
                new Tensor[] { tf.constant(1), tf.constant(-1) });
            var squeeze_sizes = array_ops.squeeze(slice);

            // Split the gradient (not the size vector) into pieces of the
            // extracted sizes along the concat axis, same argument order as
            // the eager branch above.
            // NOTE(review): the int conversion needs the axis tensor to be
            // readable at graph-construction time - confirm for graph tensors.
            out_grads = array_ops.split(grad, squeeze_sizes, (int)non_neg_concat_dim).ToList();
        }
        else
        {
            // One slice of grad per input, located by its concat offset.
            var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
            foreach (var (begin, size) in zip(offset, sizes))
                out_grads.Add(gen_array_ops.slice(grad, begin, size));
        }
    }

    // Place the null gradient for the axis input after or before the value
    // gradients, depending on where the axis sits relative to the values.
    return (end_value_index <= dim_index ?
        out_grads.ToArray().Concat(new Tensor[] { null }) :
        new Tensor[] { null }.Concat(out_grads)).ToArray();
}
/// <summary>
/// Gradient for the ExpandDims op: the incoming gradient reshaped to the
/// shape of the op's first input.
/// </summary>
/// <param name="op">The ExpandDims operation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>A two-element array: the input gradient, and null for the second input.</returns>
[RegisterGradient("ExpandDims")]
public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var input_shape = array_ops.shape(op.inputs[0]);
    var input_grad = array_ops.reshape(upstream, input_shape);
    return new Tensor[] { input_grad, null };
}
/// <summary>
/// Extract the shapes of a set of input tensors.
/// </summary>
/// <param name="inputs">Tensors whose shapes are requested.</param>
/// <returns>
/// One shape tensor per input. The individually computed shapes are returned
/// when every one of them is produced by a "Const" op; otherwise the result of
/// a single <c>shape_n</c> call over all inputs is returned.
/// </returns>
private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
{
    var const_shapes = new List<Tensor>(inputs.Length);
    foreach (var input in inputs)
    {
        var shape = array_ops.shape(input);
        var is_const_shape = shape is Tensor && shape.op.type == "Const";

        // The first non-constant shape ends the scan; fall back to shape_n.
        if (!is_const_shape)
            return gen_array_ops.shape_n(inputs);

        const_shapes.Add(shape);
    }

    return const_shapes.ToArray();
}
/// <summary>
/// Gradient for GatherV2 op.
/// </summary>
/// <param name="op">The GatherV2 operation; inputs are params, indices and axis.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>
/// A three-element array, one slot per op input. For a statically known axis
/// of 0 the first slot holds an <c>IndexedSlices</c> gradient for params;
/// otherwise all slots are null because the non-zero-axis gradient is not
/// implemented here.
/// </returns>
[RegisterGradient("GatherV2")]
public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var @params = op.inputs[0];
    ops.colocate_with(@params);

    // Shape is taken as int64 and then cast down to int32.
    var params_shape = array_ops.shape(@params, out_type: tf.int64);
    params_shape = math_ops.cast(params_shape, tf.int32);

    var indices = op.inputs[1];
    var indices_size = array_ops.expand_dims(array_ops.size(indices), 0);
    var axis = op.inputs[2];
    // NOTE(review): presumably null when axis is not a constant, in which
    // case the conversion below would fail - confirm against tensor_util.
    var axis_static = tensor_util.constant_value(axis);

    // For axis 0 gathers, build an appropriately shaped IndexedSlices.
    if ((int)axis_static == 0)
    {
        // values shape = [number of indices] + params.shape[1:]
        var params_tail_shape = params_shape.slice(new Slice(start: 1));
        var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
        var values = array_ops.reshape(grad, values_shape);
        indices = array_ops.reshape(indices, indices_size);
        return new Tensor[]
        {
            new IndexedSlices(values, indices, params_shape),
            null,
            null
        };
    }

    // Gradient for a non-zero axis is not implemented. Return one null per
    // op input (params, indices, axis) so the arity matches the branch above.
    return new Tensor[] { null, null, null };
}
/// <summary>
/// Gradient for the Reshape op: the incoming gradient reshaped back to the
/// shape of the op's first input.
/// </summary>
/// <param name="op">The Reshape operation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>A two-element array: the input gradient, and null for the shape input.</returns>
[RegisterGradient("Reshape")]
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var original_shape = array_ops.shape(op.inputs[0]);
    var input_grad = array_ops.reshape(upstream, original_shape);
    return new Tensor[] { input_grad, null };
}
/// <summary>
/// Gradient for the Pack op: unstacks the incoming gradient using the op's
/// "N" and "axis" attributes.
/// </summary>
/// <param name="op">The Pack operation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>The tensors produced by unstacking grads[0].</returns>
[RegisterGradient("Pack")]
public static Tensor[] _PackGrad(Operation op, Tensor[] grads)
    => array_ops.unstack(
        grads[0],
        num: op.get_attr<int>("N"),
        axis: op.get_attr<int>("axis"));
/// <summary>
/// Gradient for the Unpack op: stacks all incoming gradients along the op's
/// "axis" attribute.
/// </summary>
/// <param name="op">The Unpack operation.</param>
/// <param name="grads">Gradients with respect to each op output.</param>
/// <returns>A single-element array holding the stacked gradient.</returns>
[RegisterGradient("Unpack")]
public static Tensor[] _UnpackGrad(Operation op, Tensor[] grads)
{
    var unpack_axis = op.get_attr<int>("axis");
    var stacked = array_ops.stack(grads, axis: unpack_axis);
    return new[] { stacked };
}
/// <summary>
/// Gradient for the Pad op: slices the region of the incoming gradient that
/// corresponds to the unpadded input.
/// </summary>
/// <param name="op">The Pad operation; inputs[0] is the padded tensor, inputs[1] the paddings.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>
/// The input gradient followed by one null per remaining op input
/// (two nulls when the op has three inputs, otherwise one).
/// </returns>
[RegisterGradient("Pad")]
public static Tensor[] _PadGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var x = op.inputs[0];
    var paddings = op.inputs[1];

    // Slice paddings from [0, 0] with size [rank(x), 1].
    var column_size = array_ops.stack(new Tensor[] { array_ops.rank(x), constant_op.constant(1) });
    var origin = constant_op.constant(new[] { 0, 0 });
    var pad_before = array_ops.slice(paddings, origin, column_size);

    // Make it a 1-D tensor and use it as the start of the slice into grad.
    var grad_begin = array_ops.reshape(pad_before, new[] { -1 });
    var x_shape = array_ops.shape(x);
    var x_grad = array_ops.slice(upstream, grad_begin, x_shape);

    return len(op.inputs) == 3
        ? new Tensor[] { x_grad, null, null }
        : new Tensor[] { x_grad, null };
}
/// <summary>
/// Gradient for the Split op: concatenates the incoming gradients along the
/// op's first input.
/// </summary>
/// <param name="op">The Split operation; inputs[0] is passed to concat as the axis.</param>
/// <param name="grads">Gradients with respect to each op output.</param>
/// <returns>A two-element array: null for inputs[0], then the concatenated gradient.</returns>
[RegisterGradient("Split")]
public static Tensor[] _SplitGrad(Operation op, Tensor[] grads)
{
    var pieces = list(grads);
    var split_dim = op.inputs[0];
    var value_grad = array_ops.concat(pieces, split_dim);
    return new Tensor[] { null, value_grad };
}
/// <summary>
/// Gradient for the Slice op: pads the incoming gradient back out to the
/// shape of the sliced input.
/// </summary>
/// <param name="op">The Slice operation; inputs[0] is the input, inputs[1] the begin vector.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>A three-element array: the padded gradient, then two nulls.</returns>
[RegisterGradient("Slice")]
public static Tensor[] _SliceGrad(Operation op, Tensor[] grads)
{
    var upstream = grads[0];
    var source = op.inputs[0];
    var begin = op.inputs[1];

    var source_rank = array_ops.rank(source);
    var sliced_shape = array_ops.shape(op.outputs[0]);

    // Both padding columns are reshaped to [rank, 1] before being joined.
    var column_shape = array_ops.stack(new Tensor[] { source_rank, ops.convert_to_tensor(1) });
    var pad_before = array_ops.reshape(begin, column_shape);

    // Trailing padding = input shape - slice size - begin.
    var trailing = array_ops.shape(source) - sliced_shape - begin;
    var pad_after = array_ops.reshape(trailing, column_shape);

    var paddings = array_ops.concat(new Tensor[] { pad_before, pad_after }, 1);
    var source_grad = array_ops.pad(upstream, paddings);

    return new Tensor[] { source_grad, null, null };
}
/// <summary>
/// Gradient for the Squeeze op: the incoming gradient reshaped to the shape
/// of the op's first input.
/// </summary>
/// <param name="op">The Squeeze operation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>A single-element array holding the input gradient.</returns>
[RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{
    var input_grad = _ReshapeToInput(op, grads[0]);
    return new Tensor[] { input_grad };
}
/// <summary>
/// Gradient registered for the StopGradient op: always a single null entry.
/// </summary>
/// <param name="op">The operation (unused).</param>
/// <param name="grads">Gradients with respect to the op outputs (unused).</param>
/// <returns>A single-element array containing null.</returns>
[RegisterGradient("StopGradient")]
public static Tensor[] _NoGradient(Operation op, Tensor[] grads)
    => new Tensor[] { null };
/// <summary>
/// Gradient for StridedSlice op.
/// </summary>
/// <param name="op">The StridedSlice operation; inputs are input, begin, end and strides.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>
/// A four-element array: the gradient for the sliced input, then null for
/// begin, end and strides.
/// </returns>
[RegisterGradient("StridedSlice")]
public static Tensor[] _StridedSliceGrad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var begin = op.inputs[1];
    var end = op.inputs[2];
    var strides = op.inputs[3];

    // Input shape, using the same dtype as the begin tensor.
    var x = array_ops.shape(op.inputs[0], out_type: begin.dtype);

    var x_static = tensor_util.constant_value(x);
    var begin_static = tensor_util.constant_value(begin);
    var end_static = tensor_util.constant_value(end);
    var strides_static = tensor_util.constant_value(strides);

    // Prefer the statically known value of each argument, but fall back to
    // the symbolic tensor when no static value is available (presumably
    // signalled by null - TODO confirm), instead of passing null on.
    Tensor shape_arg = x;
    if (!(x_static is null))
        shape_arg = x_static;

    Tensor begin_arg = begin;
    if (!(begin_static is null))
        begin_arg = begin_static;

    Tensor end_arg = end;
    if (!(end_static is null))
        end_arg = end_static;

    Tensor strides_arg = strides;
    if (!(strides_static is null))
        strides_arg = strides_static;

    return new Tensor[]
    {
        array_ops.strided_slice_grad(
            shape_arg,
            begin_arg,
            end_arg,
            strides_arg,
            grad,
            begin_mask: op.get_attr<long>("begin_mask"),
            end_mask: op.get_attr<long>("end_mask"),
            ellipsis_mask: op.get_attr<long>("ellipsis_mask"),
            new_axis_mask: op.get_attr<long>("new_axis_mask"),
            shrink_axis_mask: op.get_attr<long>("shrink_axis_mask")),
        null,
        null,
        null
    };
}
/// <summary>
/// Gradient for the StridedSliceGrad op.
/// </summary>
/// <param name="op">
/// The StridedSliceGrad operation. begin, end and strides are read from
/// inputs[1], inputs[2] and inputs[3]; inputs[0] is assumed to be the shape
/// and inputs[4] the incoming gradient (dy), as in the TensorFlow op definition.
/// </param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>
/// A five-element array, one slot per op input: null for shape, begin, end
/// and strides, followed by the gradient for dy.
/// </returns>
[RegisterGradient("StridedSliceGrad")]
public static Tensor[] _StridedSliceGradGrad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var begin = op.inputs[1];
    var end = op.inputs[2];
    var strides = op.inputs[3];

    // The gradient with respect to dy is the same strided slice applied to grad.
    var dy_grad = gen_array_ops.strided_slice(
        grad,
        begin,
        end,
        strides,
        begin_mask: op.get_attr<long>("begin_mask"),
        end_mask: op.get_attr<long>("end_mask"),
        ellipsis_mask: op.get_attr<long>("ellipsis_mask"),
        new_axis_mask: op.get_attr<long>("new_axis_mask"),
        shrink_axis_mask: op.get_attr<long>("shrink_axis_mask"));

    // Four nulls so that dy_grad lines up with the fifth input rather than
    // with the strides input.
    return new Tensor[]
    {
        null,
        null,
        null,
        null,
        dy_grad
    };
}
/// <summary>
/// Reshapes a gradient to the shape of the op's first input.
/// </summary>
/// <param name="op">The operation whose inputs[0] supplies the target shape.</param>
/// <param name="grad">The gradient to reshape.</param>
/// <returns>The reshaped gradient.</returns>
private static Tensor _ReshapeToInput(Operation op, Tensor grad)
{
    var input_shape = array_ops.shape(op.inputs[0]);
    return array_ops.reshape(grad, input_shape);
}
/// <summary>
/// Gradient for the Transpose op: transposes the incoming gradient with the
/// inverse of the op's permutation input.
/// </summary>
/// <param name="op">The Transpose operation; inputs[1] is the permutation.</param>
/// <param name="grads">Gradients with respect to the op outputs; only grads[0] is used.</param>
/// <returns>A two-element array: the input gradient, and null for the permutation input.</returns>
[RegisterGradient("Transpose")]
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
{
    var perm = op.inputs[1];
    var upstream = grads[0];
    var inverse_perm = array_ops.invert_permutation(perm);
    var input_grad = array_ops.transpose(upstream, inverse_perm);
    return new Tensor[] { input_grad, null };
}
}
}