Skip to content

Commit 9e414f4

Browse files
committed
add _FusedBatchNormGrad
1 parent 66f7e6d commit 9e414f4

10 files changed

Lines changed: 188 additions & 22 deletions

File tree

src/TensorFlowNET.Core/Gradients/control_flow_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public class control_flow_grad
3636
/// </summary>
3737
/// <returns></returns>
3838
[RegisterGradient("Switch")]
39-
public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
39+
public Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
4040
{
4141
throw new NotImplementedException("_SwitchGrad");
4242
//graph = ops.get_default_graph()

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
108108
{
109109
// generate gradient subgraph for op.
110110
var op = queue.Dequeue();
111-
if(tf.get_default_graph()._nodes_by_name.Count > 18505)
111+
if(tf.get_default_graph()._nodes_by_name.Count > 18577)
112112
{
113113

114114
}

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,94 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
166166
};
167167
}
168168

169+
[RegisterGradient("FusedBatchNorm")]
170+
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
171+
=> _BaseFusedBatchNormGrad(op, 0, grads);
172+
173+
/// <summary>
174+
/// Return the gradients for the 3 inputs of BatchNorm.
175+
/// </summary>
176+
/// <param name="op"></param>
177+
/// <param name="version"></param>
178+
/// <param name="grads"></param>
179+
/// <returns></returns>
180+
public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
181+
{
182+
var x = op.inputs[0];
183+
var grad_y = grads[0];
184+
var scale = op.inputs[1];
185+
var epsilon = op.get_attr<float>("epsilon");
186+
var data_format = op.get_attr<string>("data_format");
187+
var is_training = op.get_attr<bool>("is_training");
188+
Func<FusedBatchNormParams, Tensor[]> grad_fun = null;
189+
190+
switch (version)
191+
{
192+
case 2:
193+
throw new NotImplementedException("");
194+
case 1:
195+
throw new NotImplementedException("");
196+
default:
197+
grad_fun = gen_nn_ops.fused_batch_norm_grad;
198+
break;
199+
}
200+
201+
if (is_training)
202+
{
203+
return grad_fun(new FusedBatchNormParams
204+
{
205+
YBackprop = grad_y,
206+
X = x,
207+
Scale = scale,
208+
ReserveSpace1 = op.outputs[3],
209+
ReserveSpace2 = op.outputs[4],
210+
ReserveSpace3 = version == 2 ? op.outputs[5] : null,
211+
Epsilon = epsilon,
212+
DataFormat = data_format,
213+
IsTraining = is_training
214+
});
215+
}
216+
else
217+
{
218+
var pop_mean = op.inputs[3];
219+
var pop_var = op.inputs[4];
220+
if (data_format == "NCHW")
221+
throw new NotImplementedException("");
222+
223+
var results = grad_fun(new FusedBatchNormParams
224+
{
225+
YBackprop = grad_y,
226+
X = x,
227+
Scale = scale,
228+
ReserveSpace1 = op.outputs[3],
229+
ReserveSpace2 = op.outputs[4],
230+
ReserveSpace3 = version == 2 ? op.outputs[5] : null,
231+
Epsilon = epsilon,
232+
DataFormat = data_format,
233+
IsTraining = is_training
234+
});
235+
236+
var (dx, dscale, doffset) = (results[0], results[1], results[2]);
237+
if (data_format == "NCHW")
238+
throw new NotImplementedException("");
239+
240+
return new Tensor[]
241+
{
242+
dx,
243+
dscale,
244+
doffset,
245+
null,
246+
null
247+
};
248+
}
249+
}
250+
251+
[RegisterGradient("BatchNormWithGlobalNormalization")]
252+
public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
253+
{
254+
throw new NotImplementedException("BatchNormWithGlobalNormalization");
255+
}
256+
169257
private static bool IsZero(Tensor g)
170258
{
171259
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,6 @@ namespace Tensorflow.Operations
2727
/// </summary>
2828
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
2929
{
30-
31-
32-
/// <summary>
33-
/// The boolean tensor for the cond predicate
34-
/// </summary>
35-
private Tensor _pred;
36-
37-
public Tensor pred => _pred;
38-
39-
/// <summary>
40-
/// 0 or 1 representing this branch
41-
/// </summary>
42-
private int _branch;
43-
4430
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
4531

4632
/// <summary>

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,19 @@ public abstract class ControlFlowContext : IObjectLife
4545
/// The predicate tensor in this branch
4646
/// </summary>
4747
protected Tensor _pivot;
48-
public Tensor pivot
49-
{
50-
get => _pivot;
51-
}
48+
public Tensor pivot => _pivot;
49+
50+
/// <summary>
51+
/// The boolean tensor for the cond predicate
52+
/// </summary>
53+
protected Tensor _pred;
54+
public Tensor pred => _pred;
55+
56+
/// <summary>
57+
/// 0 or 1 representing this branch
58+
/// </summary>
59+
protected int _branch;
60+
public int branch => _branch;
5261

5362
protected Stack<ControlFlowContext> _context_stack;
5463
protected ControlFlowContext _outer_context;
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Operations
6+
{
7+
public class FusedBatchNormParams
8+
{
9+
public string Name { get; set; }
10+
public Tensor YBackprop { get; set; }
11+
public Tensor X { get; set; }
12+
public Tensor Scale { get; set; }
13+
public Tensor ReserveSpace1 { get; set; }
14+
public Tensor ReserveSpace2 { get; set; }
15+
public Tensor ReserveSpace3 { get; set; }
16+
public float Epsilon { get; set; }
17+
public string DataFormat { get; set; }
18+
public bool IsTraining { get; set; }
19+
20+
public FusedBatchNormParams()
21+
{
22+
Epsilon = 0.0001f;
23+
DataFormat = "NHWC";
24+
IsTraining = true;
25+
}
26+
}
27+
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,35 @@ public static Tensor elu(Tensor features, string name = "Elu")
156156
return op.output;
157157
}
158158

159+
/// <summary>
160+
/// Gradient for batch normalization.
161+
/// </summary>
162+
/// <param name="y_backprop"></param>
163+
/// <param name="x"></param>
164+
/// <param name="scale"></param>
165+
/// <param name="reserve_space_1"></param>
166+
/// <param name="reserve_space_2"></param>
167+
/// <param name="epsilon"></param>
168+
/// <param name="data_format"></param>
169+
/// <param name="is_training"></param>
170+
/// <param name="name"></param>
171+
/// <returns></returns>
172+
public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params)
173+
{
174+
var op = _op_def_lib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new
175+
{
176+
y_backprop = @params.YBackprop,
177+
x = @params.X,
178+
scale = @params.Scale,
179+
reserve_space_1 = @params.ReserveSpace1,
180+
reserve_space_2 = @params.ReserveSpace2,
181+
epsilon = @params.Epsilon,
182+
data_format = @params.DataFormat,
183+
is_training = @params.IsTraining
184+
});
185+
return op.outputs;
186+
}
187+
159188
public static Tensor[] fused_batch_norm(Tensor x,
160189
Tensor scale,
161190
Tensor offset,

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,9 @@ private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, Map
218218
return grouped_inputs.ToArray();
219219
}
220220

221+
public T get_attr<T>(string name)
222+
=> (T)get_attr(name);
223+
221224
public object get_attr(string name)
222225
{
223226
AttrValue x = null;

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,31 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
557557
throw new NotImplementedException("ZerosLikeOutsideLoop");
558558
return array_ops.zeros_like(val, optimize: false);
559559
}
560-
561-
throw new NotImplementedException("ZerosLikeOutsideLoop");
560+
else
561+
{
562+
var op_ctxt = op._get_control_flow_context();
563+
if(op_ctxt != null)
564+
{
565+
// We are in a cond context. Use a switch to create zeros only when needed.
566+
var pred = op_ctxt.pred;
567+
var branch = op_ctxt.branch;
568+
var switch_val = @switch(op.inputs[0], pred)[1 - branch];
569+
var pivot = array_ops.identity(switch_val);
570+
if (val.dtype == dtypes.resource)
571+
throw new NotImplementedException("");
572+
var zeros_shape = array_ops.shape_internal(switch_val, optimize: false);
573+
// Ensure ops created within array_ops.zeros are dominated by switch in
574+
// cond context.
575+
return tf_with(ops.control_dependencies(new[] { pivot }), delegate
576+
{
577+
return array_ops.zeros(zeros_shape, dtype: val.dtype);
578+
});
579+
}
580+
else
581+
{
582+
return array_ops.zeros_like(val, optimize: false);
583+
}
584+
}
562585
}
563586

564587
/// <summary>

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ public static class dtypes
3333
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
3434
public static TF_DataType float16 = TF_DataType.TF_HALF;
3535
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
36+
public static TF_DataType resource = TF_DataType.TF_RESOURCE;
3637

3738
/// <summary>
3839
///

0 commit comments

Comments
 (0)