@@ -211,14 +211,14 @@ public static Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool kee
211211 /// <returns> The reduced tensor.</returns>
212212 public static Tensor reduce_logsumexp ( Tensor input_tensor , int [ ] axis = null , bool keepdims = false , string name = null )
213213 {
214- with ( ops . name_scope ( name , "ReduceLogSumExp" , new { input_tensor } ) , scope =>
214+ return with ( ops . name_scope ( name , "ReduceLogSumExp" , new { input_tensor } ) , scope =>
215215 {
216216 var raw_max = reduce_max ( input_tensor , axis , true ) ;
217217 var my_max = array_ops . stop_gradient ( array_ops . where ( gen_math_ops . is_finite ( raw_max ) , raw_max , array_ops . zeros_like ( raw_max ) ) ) ;
218218 var result = gen_math_ops . log (
219219 reduce_sum (
220220 gen_math_ops . exp ( gen_math_ops . sub ( input_tensor , my_max ) ) ,
221- new Tensor ( axis ) ,
221+ axis [ 0 ] ,
222222 keepdims ) ) ;
223223 if ( ! keepdims )
224224 {
@@ -227,7 +227,6 @@ public static Tensor reduce_logsumexp(Tensor input_tensor, int[] axis = null, bo
227227 result = gen_math_ops . add ( result , my_max ) ;
228228 return _may_reduce_to_scalar ( keepdims , axis , result ) ;
229229 } ) ;
230- return null ;
231230 }
232231
233232 public static Tensor reduce_max ( Tensor input_tensor , int [ ] axis = null , bool keepdims = false , string name = null )
@@ -284,13 +283,17 @@ private static Tensor _may_reduce_to_scalar(bool keepdims, Tensor axis, Tensor o
284283 if ( ! common_shapes . has_fully_defined_shape ( output ) &&
285284 ! keepdims &&
286285 axis == null )
286+ // We want set_shape to be reflected in the C API graph for when we run it.
287287 output . shape = new long [ 0 ] ;
288288 return output ;
289289 }
290290
291291 private static Tensor _may_reduce_to_scalar ( bool keepdims , int [ ] axis , Tensor output )
292292 {
293- output . shape = new long [ 0 ] ;
293+ if ( ! common_shapes . has_fully_defined_shape ( output ) &&
294+ ! keepdims &&
295+ axis == null )
296+ output . shape = new long [ 0 ] ;
294297 return output ;
295298 }
296299
@@ -312,7 +315,7 @@ private static Tensor _ReductionDims(Tensor x, int[] axis)
312315 if ( axis != null )
313316 {
314317 // should return axis. or check before.
315- return null ;
318+ return ops . convert_to_tensor ( axis , TF_DataType . TF_INT32 ) ;
316319 }
317320 else
318321 {
0 commit comments