@@ -166,6 +166,94 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
166166 } ;
167167 }
168168
/// <summary>
/// Gradient function registered for the "FusedBatchNorm" op type.
/// Forwards to <see cref="_BaseFusedBatchNormGrad"/> with version 0.
/// </summary>
/// <param name="op">The forward operation whose gradient is requested.</param>
/// <param name="grads">Incoming gradients, passed through unchanged.</param>
/// <returns>Whatever <see cref="_BaseFusedBatchNormGrad"/> returns for version 0.</returns>
[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
{
    const int version = 0;
    return _BaseFusedBatchNormGrad(op, version, grads);
}
172+
/// <summary>
/// Return the gradients for the 3 inputs of BatchNorm.
/// </summary>
/// <param name="op">
/// The forward fused batch-norm operation. Its inputs 0 and 1 are read as
/// <c>x</c> and <c>scale</c>; inputs 3 and 4 are read as the population mean
/// and variance when the op is not in training mode.
/// </param>
/// <param name="version">
/// Selects the gradient kernel. Only the default (any value other than 1 or 2)
/// is implemented; it maps to <c>gen_nn_ops.fused_batch_norm_grad</c>.
/// </param>
/// <param name="grads">Incoming gradients; only <c>grads[0]</c> is used.</param>
/// <returns>
/// In training mode, the result of the gradient kernel as-is. Otherwise a
/// five-element array <c>{ dx, dscale, doffset, null, null }</c>.
/// </returns>
/// <exception cref="NotImplementedException">
/// Thrown for <paramref name="version"/> 1 or 2, and for the "NCHW" data format
/// when the op is not in training mode.
/// </exception>
public static Tensor[] _BaseFusedBatchNormGrad(Operation op, int version, Tensor[] grads)
{
    var x = op.inputs[0];
    var grad_y = grads[0];
    var scale = op.inputs[1];
    var epsilon = op.get_attr<float>("epsilon");
    var data_format = op.get_attr<string>("data_format");
    var is_training = op.get_attr<bool>("is_training");
    Func<FusedBatchNormParams, Tensor[]> grad_fun = null;

    switch (version)
    {
        case 2:
            throw new NotImplementedException("_BaseFusedBatchNormGrad version 2 is not implemented");
        case 1:
            throw new NotImplementedException("_BaseFusedBatchNormGrad version 1 is not implemented");
        default:
            grad_fun = gen_nn_ops.fused_batch_norm_grad;
            break;
    }

    if (is_training)
    {
        // Training mode: reuse the values the forward op emitted on outputs 3 and 4.
        return grad_fun(new FusedBatchNormParams
        {
            YBackprop = grad_y,
            X = x,
            Scale = scale,
            ReserveSpace1 = op.outputs[3],
            ReserveSpace2 = op.outputs[4],
            ReserveSpace3 = version == 2 ? op.outputs[5] : null,
            Epsilon = epsilon,
            DataFormat = data_format,
            IsTraining = is_training
        });
    }
    else
    {
        var pop_mean = op.inputs[3];
        var pop_var = op.inputs[4];
        if (data_format == "NCHW")
            throw new NotImplementedException(
                "_BaseFusedBatchNormGrad: data_format NCHW is not implemented when is_training is false");

        // Inference mode: the gradient kernel must receive the population
        // statistics supplied as inputs to the forward op, not the forward
        // op's outputs 3 and 4.
        var results = grad_fun(new FusedBatchNormParams
        {
            YBackprop = grad_y,
            X = x,
            Scale = scale,
            ReserveSpace1 = pop_mean,
            ReserveSpace2 = pop_var,
            ReserveSpace3 = version == 2 ? op.outputs[5] : null,
            Epsilon = epsilon,
            DataFormat = data_format,
            IsTraining = is_training
        });

        var (dx, dscale, doffset) = (results[0], results[1], results[2]);

        // No gradient is returned for the population mean / variance inputs.
        return new Tensor[]
        {
            dx,
            dscale,
            doffset,
            null,
            null
        };
    }
}
250+
/// <summary>
/// Gradient function registered for the "BatchNormWithGlobalNormalization" op type.
/// Not implemented: every call throws.
/// </summary>
/// <param name="op">Unused.</param>
/// <param name="grads">Unused.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
// NOTE(review): the return type here is Tensor, while the FusedBatchNorm
// gradient functions return Tensor[] - confirm which one the gradient registry expects.
[RegisterGradient("BatchNormWithGlobalNormalization")]
public static Tensor _BatchNormWithGlobalNormalizationGrad(Operation op, Tensor[] grads)
    => throw new NotImplementedException("BatchNormWithGlobalNormalization");
256+
169257 private static bool IsZero ( Tensor g )
170258 {
171259 if ( new string [ ] { "ZerosLike" , "Zeros" } . Contains ( g . op . type ) )
0 commit comments