@@ -72,10 +72,7 @@ public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
7272
7373 public static bool _ShapesFullySpecifiedAndEqual ( Tensor x , Tensor y , Tensor grad )
7474 {
75- if ( x . NDims == 0 && y . NDims == 0 && grad . NDims == 0 ) return true ;
76-
77- return string . Join ( "," , x . shape ) . Equals ( string . Join ( "," , y . shape ) ) &&
78- string . Join ( "," , x . shape ) . Equals ( string . Join ( "," , grad . shape ) ) ;
75+ return x . NDims == y . NDims && y . NDims == grad . NDims && x . NDims > - 1 ;
7976 }
8077
8178 public static ( Tensor , Tensor ) _SumGrad ( Operation op , Tensor grad )
@@ -110,14 +107,15 @@ public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
110107 x = math_ops . conj ( x ) ;
111108 y = math_ops . conj ( y ) ;
112109
113- var realdiv1 = gen_math_ops . real_div ( grad , y ) ;
114- var reduce_sum1 = math_ops . reduce_sum ( realdiv1 , rx ) ;
115- var realdiv2 = gen_math_ops . real_div ( - x , y ) ;
116- var realdiv3 = gen_math_ops . real_div ( realdiv2 , y ) ;
117- var mul = grad * realdiv3 ;
118- var reduce_sum2 = math_ops . reduce_sum ( mul , ry ) ;
110+ var realdiv1 = gen_math_ops . real_div ( - x , y ) ;
111+ var realdiv2 = gen_math_ops . real_div ( realdiv1 , y ) ;
112+ var reduce_sum1 = math_ops . reduce_sum ( grad * realdiv2 , ry ) ;
113+ var reshape1 = gen_array_ops . reshape ( reduce_sum1 , sy ) ;
114+ var realdiv3 = gen_math_ops . real_div ( grad , y ) ;
115+ var reduce_sum2 = math_ops . reduce_sum ( realdiv3 , rx ) ;
116+ var reshape2 = gen_array_ops . reshape ( reduce_sum2 , sx ) ;
119117
120- return ( gen_array_ops . reshape ( reduce_sum1 , sx ) , gen_array_ops . reshape ( reduce_sum2 , sy ) ) ;
118+ return ( reshape2 , reshape1 ) ;
121119 }
122120
123121 public static ( Tensor , Tensor ) _PowGrad ( Operation op , Tensor grad )
@@ -135,17 +133,16 @@ public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
135133 var gx = gen_array_ops . reshape ( math_ops . reduce_sum ( grad * y * gen_math_ops . pow ( x , y - 1.0 ) , rx ) , sx ) ;
136134 Tensor log_x = null ;
137135 // Avoid false singularity at x = 0
136+ Tensor mask = null ;
138137 if ( x . dtype . is_complex ( ) )
139- {
140138 throw new NotImplementedException ( "x.dtype.is_complex()" ) ;
141- }
142139 else
143- {
144- var x1 = gen_array_ops . log ( x ) ;
145- var y1 = array_ops . zeros_like ( x ) ;
146- log_x = array_ops . where ( x > 0.0 , x1 , y1 ) ;
147- }
148-
140+ mask = x > 0.0f ;
141+ var ones = array_ops . ones_like ( x ) ;
142+ var safe_x = array_ops . where ( mask , x , ones ) ;
143+ var x1 = gen_array_ops . log ( safe_x ) ;
144+ var y1 = array_ops . zeros_like ( x ) ;
145+ log_x = array_ops . where ( mask , x1 , y1 ) ;
149146 var gy = gen_array_ops . reshape ( math_ops . reduce_sum ( grad * z * log_x , ry ) , sy ) ;
150147
151148 return ( gx , gy ) ;
0 commit comments