@@ -168,6 +168,96 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
168168 return new Tensor [ ] { math_ops . truediv ( sum_grad , math_ops . cast ( factor , sum_grad . dtype ) ) , null } ;
169169 }
170170
171+ /// <summary>
172+ /// Gradient for Max.
173+ /// </summary>
174+ /// <param name="op"></param>
175+ /// <param name="grads"></param>
176+ /// <returns></returns>
177+ [ RegisterGradient ( "Max" ) ]
178+ public static Tensor [ ] _MaxGrad ( Operation op , Tensor [ ] grads )
179+ {
180+ return _MinOrMaxGrad ( op , grads ) ;
181+ }
182+
183+ /// <summary>
184+ /// Gradient for Min.
185+ /// </summary>
186+ /// <param name="op"></param>
187+ /// <param name="grads"></param>
188+ /// <returns></returns>
189+ [ RegisterGradient ( "Min" ) ]
190+ public static Tensor [ ] _MinGrad ( Operation op , Tensor [ ] grads )
191+ {
192+ return _MinOrMaxGrad ( op , grads ) ;
193+ }
194+
195+ private static Tensor [ ] _MinOrMaxGrad ( Operation op , Tensor [ ] grads )
196+ {
197+ var grad = grads [ 0 ] ;
198+ var input_shape = array_ops . shape ( op . inputs [ 0 ] ) ;
199+ var output_shape_kept_dims = math_ops . reduced_shape ( input_shape , op . inputs [ 1 ] ) ;
200+ var y = op . outputs [ 0 ] ;
201+ y = array_ops . reshape ( y , output_shape_kept_dims ) ;
202+ grad = array_ops . reshape ( grad , output_shape_kept_dims ) ;
203+
204+ // Compute the number of selected (maximum or minimum) elements in each
205+ // reduction dimension. If there are multiple minimum or maximum elements
206+ // then the gradient will be divided between them.
207+ var indicators = math_ops . cast ( math_ops . equal ( y , op . inputs [ 0 ] ) , grad . dtype ) ;
208+ var num_selected = array_ops . reshape ( math_ops . reduce_sum ( indicators , op . inputs [ 1 ] ) , output_shape_kept_dims ) ;
209+
210+ return new Tensor [ ] { math_ops . div ( indicators , num_selected ) * grad , null } ;
211+ }
212+
213+ /// <summary>
214+ /// Returns grad*(x > y, x <= y) with type of grad.
215+ /// </summary>
216+ /// <param name="op"></param>
217+ /// <param name="grads"></param>
218+ /// <returns></returns>
219+ [ RegisterGradient ( "Maximum" ) ]
220+ public static Tensor [ ] _MaximumGrad ( Operation op , Tensor [ ] grads )
221+ {
222+ return _MaximumMinimumGrad ( op , grads [ 0 ] ) ;
223+ }
224+
225+ /// <summary>
226+ /// Returns grad*(x < y, x >= y) with type of grad.
227+ /// </summary>
228+ /// <param name="op"></param>
229+ /// <param name="grads"></param>
230+ /// <returns></returns>
231+ [ RegisterGradient ( "Minimum" ) ]
232+ public static Tensor [ ] _MinimumGrad ( Operation op , Tensor [ ] grads )
233+ {
234+ return _MaximumMinimumGrad ( op , grads [ 0 ] ) ;
235+ }
236+
237+ /// <summary>
238+ /// Factor out the code for the gradient of Maximum or Minimum.
239+ /// </summary>
240+ /// <param name="op"></param>
241+ /// <param name="grad"></param>
242+ /// <returns></returns>
243+ private static Tensor [ ] _MaximumMinimumGrad ( Operation op , Tensor grad )
244+ {
245+ var x = op . inputs [ 0 ] ;
246+ var y = op . inputs [ 1 ] ;
247+ var gdtype = grad . dtype ;
248+ var sx = array_ops . shape ( x ) ;
249+ var sy = array_ops . shape ( y ) ;
250+ var gradshape = array_ops . shape ( grad ) ;
251+ var zeros = array_ops . zeros ( gradshape , gdtype ) ;
252+ var xmask = gen_math_ops . greater_equal ( x , y ) ;
253+ var ( rx , ry ) = gen_array_ops . broadcast_gradient_args ( sx , sy ) ;
254+ var xgrad = array_ops . where ( xmask , grad , zeros ) ;
255+ var ygrad = array_ops . where ( xmask , zeros , grad ) ;
256+ var gx = array_ops . reshape ( math_ops . reduce_sum ( xgrad , rx ) , sx ) ;
257+ var gy = array_ops . reshape ( math_ops . reduce_sum ( ygrad , ry ) , sy ) ;
258+ return new Tensor [ ] { gx , gy } ;
259+ }
260+
171261 [ RegisterGradient ( "Neg" ) ]
172262 public static Tensor [ ] _NegGrad ( Operation op , Tensor [ ] grads )
173263 {
0 commit comments