Skip to content

Commit e75ed4f

Browse files
committed
add array_grad.ConcatV2 gradient
1 parent 1da8138 commit e75ed4f

5 files changed

Lines changed: 161 additions & 30 deletions

File tree

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using Tensorflow.Operations;
6+
using static Tensorflow.Python;
7+
8+
namespace Tensorflow.Gradients
9+
{
10+
/// <summary>
11+
/// tensorflow\python\ops\array_grad.py
12+
/// </summary>
13+
public class array_grad
14+
{
15+
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
16+
{
17+
var grad = grads[0];
18+
return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1);
19+
}
20+
21+
/// <summary>
22+
/// Gradient for concat op.
23+
/// </summary>
24+
/// <param name="op">An operation.</param>
25+
/// <param name="grad">
26+
/// `Tensor` or `IndexedSlices` representing the gradients with respect
27+
/// to each output of the op.
28+
/// </param>
29+
/// <param name="start_value_index">An integer index of the first value in the op.inputs.</param>
30+
/// <param name="end_value_index">An integer index of the last value in the op.inputs.</param>
31+
/// <param name="dim_index">An interger index of concat_dim or axis parameter in op.inputs.</param>
32+
/// <returns>
33+
/// Tensors representing the partial gradients with respect to each input
34+
/// of the op.
35+
/// </returns>
36+
private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_value_index, int end_value_index, int dim_index)
37+
{
38+
// Degenerate concatenation, just return grad.
39+
if (len(op.inputs) == 2)
40+
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
41+
42+
var concat_dim = op.inputs[dim_index];
43+
var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
44+
45+
var out_grads = new List<Tensor>();
46+
if (constant_op.is_constant(concat_dim))
47+
{
48+
/*If concat_dim is a constant defined in a different context,
49+
then we duplicate it in the current context to avoid passing it
50+
through an Enter node.
51+
This is a small optimization in general, but it is required when
52+
compiling with XLA, as XLA needs the concat input to be folded into a
53+
constant.*/
54+
var grad_context = control_flow_util.GetOutputContext(grad.op);
55+
var dim_context = control_flow_util.GetOutputContext(concat_dim.op);
56+
if (dim_context != grad_context)
57+
{
58+
var value = tensor_util.constant_value(concat_dim);
59+
concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
60+
}
61+
}
62+
63+
// Using mod here for convenience since concat_dim is already verified
64+
// in concat implementation to be within the allowed [-rank, rank) range.
65+
var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);
66+
67+
// Get the inputs' tensor shapes
68+
var sizes = _ExtractInputShapes(input_values);
69+
70+
/* The magic number of 16 was found through benchmarking a range of sizes
71+
on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
72+
cases when switching implementations at N=16, but it is possible that
73+
there will be a small number of performance regressions.*/
74+
if (len(sizes) > 16)
75+
{
76+
// extract the size of each input along the concat dimension
77+
var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
78+
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
79+
new Tensor[] { tf.constant(1), tf.constant(-1) });
80+
var squeeze_sizes = array_ops.squeeze(slice);
81+
out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
82+
}
83+
else
84+
{
85+
var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
86+
foreach (var (begin, size) in zip(offset, sizes))
87+
out_grads.Add(gen_ops.slice(grad, begin, size));
88+
}
89+
90+
return (end_value_index <= dim_index ?
91+
out_grads.ToArray().Concat(null) :
92+
new Tensor[] { null }.Concat(out_grads)).ToArray();
93+
}
94+
95+
/// <summary>
96+
/// Extract the shapes of a set of input tensors.
97+
/// </summary>
98+
/// <param name="inputs"></param>
99+
/// <returns></returns>
100+
private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
101+
{
102+
var sizes = new Tensor[inputs.Length];
103+
bool fully_known = true;
104+
for(int i = 0; i < inputs.Length; i++)
105+
{
106+
var x = inputs[i];
107+
108+
var input_shape = array_ops.shape(x);
109+
if (!(input_shape is Tensor) || input_shape.op.type != "Const")
110+
{
111+
fully_known = false;
112+
break;
113+
}
114+
115+
sizes[i] = input_shape;
116+
}
117+
118+
if (fully_known)
119+
return sizes;
120+
else
121+
return gen_ops.shape_n(inputs);
122+
}
123+
124+
125+
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
126+
{
127+
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
128+
}
129+
130+
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
131+
{
132+
return new Tensor[] { _ReshapeToInput(op, grads[0]) };
133+
}
134+
135+
private static Tensor _ReshapeToInput(Operation op, Tensor grad)
136+
{
137+
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
138+
}
139+
140+
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
141+
{
142+
var p = op.inputs[1];
143+
return new Tensor[] { array_ops.transpose(grads[0], array_ops.invert_permutation(p)), null };
144+
}
145+
}
146+
}

src/TensorFlowNET.Core/Gradients/array_grad.py.cs

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/TensorFlowNET.Core/Gradients/nn_grad.py.cs renamed to src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
namespace Tensorflow.Gradients
88
{
9+
/// <summary>
10+
///
11+
/// </summary>
912
public class nn_grad
1013
{
1114
/// <summary>

src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operatio
2222
return math_grad._AddGrad(oper, out_grads);
2323
case "BiasAdd":
2424
return nn_grad._BiasAddGrad(oper, out_grads);
25+
case "ConcatV2":
26+
return array_grad._ConcatGradV2(oper, out_grads);
2527
case "Exp":
2628
return math_grad._ExpGrad(oper, out_grads);
2729
case "Identity":

src/TensorFlowNET.Core/Tensors/constant_op.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,5 +88,15 @@ public static Tensor _tensor_shape_tensor_conversion_function(TensorShape s, TF_
8888

8989
return constant_op.constant(s_list, name: name);
9090
}
91+
92+
public static bool is_constant(ITensorOrOperation tensor_or_op)
93+
{
94+
if (tensor_or_op is Tensor tensor)
95+
return tensor.op.type == "Const";
96+
else if (tensor_or_op is Operation op)
97+
return op.type == "Const";
98+
else
99+
throw new ValueError("is_constant");
100+
}
91101
}
92102
}

0 commit comments

Comments
 (0)