@@ -17,6 +17,8 @@ limitations under the License.
1717using NumSharp ;
1818using System ;
1919using System . Collections . Generic ;
20+ using System . Linq ;
21+ using Tensorflow . Framework ;
2022using static Tensorflow . Binding ;
2123
2224namespace Tensorflow
@@ -66,6 +68,44 @@ public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF
6668 } ) ;
6769 }
6870
71+ public static Tensor boolean_mask < T1 , T2 > ( T1 tensor , T2 mask , string name = "boolean_mask" , int axis = 0 )
72+ {
73+ return tf_with ( ops . name_scope ( name , values : new { tensor , mask } ) , delegate
74+ {
75+ var tensor_tensor = ops . convert_to_tensor ( tensor , name : "tensor" ) ;
76+ var mask_tensor = ops . convert_to_tensor ( mask , name : "mask" ) ;
77+
78+ var shape_mask = mask_tensor . TensorShape ;
79+ var ndims_mask = shape_mask . ndim ;
80+ var shape_tensor = tensor_tensor . TensorShape ;
81+
82+ if ( ndims_mask < 1 )
83+ throw new ValueError ( "mask cannot be scalar." ) ;
84+
85+ var leading_size = gen_math_ops . prod ( shape ( tensor_tensor ) [ $ "{ axis } :{ axis + ndims_mask } "] , new [ ] { 0 } ) ;
86+ var shape1 = concat ( new [ ]
87+ {
88+ shape ( tensor_tensor ) [ $ ":{ axis } "] ,
89+ tf . expand_dims ( leading_size , 0 ) ,
90+ shape ( tensor_tensor ) [ $ "{ axis + ndims_mask } :"]
91+ } , 0 ) ;
92+ tensor_tensor = reshape ( tensor , shape1 ) ;
93+ var first_dim = shape_tensor . dims . Skip ( axis ) . Take ( ndims_mask ) . First ( ) ;
94+ var s1 = tensor_shape . as_shape ( shape_tensor . dims . Take ( axis ) . ToArray ( ) ) ;
95+ var s2 = s1 . concatenate ( new [ ] { first_dim } ) . concatenate ( shape_tensor . dims . Skip ( axis + ndims_mask ) . ToArray ( ) ) ;
96+ tensor_tensor . set_shape ( s2 ) ;
97+
98+ mask_tensor = reshape ( mask_tensor , new [ ] { - 1 } ) ;
99+ return _apply_mask_1d ( tensor_tensor , mask_tensor , axis ) ;
100+ } ) ;
101+ }
102+
103+ private static Tensor _apply_mask_1d ( Tensor reshaped_tensor , Tensor mask , int axis = 0 )
104+ {
105+ var indices = squeeze ( where ( mask ) , axis : new [ ] { 1 } ) ;
106+ return gather ( reshaped_tensor , indices , axis : axis ) ;
107+ }
108+
69109 public static Tensor zeros ( Tensor shape , TF_DataType dtype = TF_DataType . TF_FLOAT , string name = null )
70110 {
71111 dtype = dtype . as_base_dtype ( ) ;
@@ -336,7 +376,12 @@ public static Tensor where(Tensor condition, object x = null, object y = null, s
336376 {
337377 if ( x == null && y == null )
338378 {
339- throw new NotImplementedException ( "where" ) ;
379+ return tf_with ( ops . name_scope ( name , "Where" , new { condition } ) , scope =>
380+ {
381+ name = scope ;
382+ condition = ops . convert_to_tensor ( condition , preferred_dtype : dtypes . @bool , name : "condition" ) ;
383+ return gen_array_ops . where ( condition : condition , name : name ) ;
384+ } ) ;
340385 }
341386 else if ( x != null && y != null )
342387 {
0 commit comments