Skip to content

Commit f1a3881

Browse files
committed
tf.boolean_mask SciSharp#396
1 parent 6c6c8c4 commit f1a3881

8 files changed

Lines changed: 84 additions & 3 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,19 @@ public partial class tensorflow
3939
public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null)
4040
=> gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name);
4141

42+
/// <summary>
43+
/// Apply boolean mask to tensor.
44+
/// </summary>
45+
/// <typeparam name="T1"></typeparam>
46+
/// <typeparam name="T2"></typeparam>
47+
/// <param name="tensor">N-D tensor.</param>
48+
/// <param name="mask">K-D boolean tensor, K <= N and K must be known statically.</param>
49+
/// <param name="name"></param>
50+
/// <param name="axis">A 0-D int Tensor representing the axis in tensor to mask from. </param>
51+
/// <returns>(N-K+1)-dimensional tensor populated by entries in tensor corresponding to True values in mask.</returns>
52+
public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
53+
=> array_ops.boolean_mask(tensor, mask, name: name, axis: axis);
54+
4255
public Tensor check_numerics(Tensor tensor, string message, string name = null)
4356
=> gen_array_ops.check_numerics(tensor, message, name: name);
4457

File renamed without changes.

src/TensorFlowNET.Core/APIs/tf.ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public partial class tensorflow
2121
public Tensor assign(Tensor @ref, object value, bool validate_shape = true, bool use_locking = true, string name = null)
2222
=> state_ops.assign(@ref, value, validate_shape, use_locking, name);
2323

24+
public void device(string device_name)
25+
=> get_default_graph().device(device_name);
26+
2427
public object get_collection(string key, string scope = "")
2528
=> get_default_graph().get_collection(key, scope: scope);
2629

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,11 @@ public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes
288288
return op;
289289
}
290290

291+
public void device(string device_name)
292+
{
293+
throw new NotImplementedException("");
294+
}
295+
291296
private void _create_op_helper(Operation op, bool compute_device = true)
292297
{
293298
_record_op_seen_by_control_dependencies(op);

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717
using NumSharp;
1818
using System;
1919
using System.Collections.Generic;
20+
using System.Linq;
21+
using Tensorflow.Framework;
2022
using static Tensorflow.Binding;
2123

2224
namespace Tensorflow
@@ -66,6 +68,44 @@ public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF
6668
});
6769
}
6870

71+
public static Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
72+
{
73+
return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate
74+
{
75+
var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor");
76+
var mask_tensor = ops.convert_to_tensor(mask, name: "mask");
77+
78+
var shape_mask = mask_tensor.TensorShape;
79+
var ndims_mask = shape_mask.ndim;
80+
var shape_tensor = tensor_tensor.TensorShape;
81+
82+
if (ndims_mask < 1)
83+
throw new ValueError("mask cannot be scalar.");
84+
85+
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 });
86+
var shape1 = concat(new[]
87+
{
88+
shape(tensor_tensor)[$":{axis}"],
89+
tf.expand_dims(leading_size, 0),
90+
shape(tensor_tensor)[$"{axis + ndims_mask}:"]
91+
}, 0);
92+
tensor_tensor = reshape(tensor, shape1);
93+
var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First();
94+
var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray());
95+
var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray());
96+
tensor_tensor.set_shape(s2);
97+
98+
mask_tensor = reshape(mask_tensor, new[] { -1 });
99+
return _apply_mask_1d(tensor_tensor, mask_tensor, axis);
100+
});
101+
}
102+
103+
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
104+
{
105+
var indices = squeeze(where(mask), axis: new[] { 1 });
106+
return gather(reshaped_tensor, indices, axis: axis);
107+
}
108+
69109
public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
70110
{
71111
dtype = dtype.as_base_dtype();
@@ -336,7 +376,12 @@ public static Tensor where(Tensor condition, object x = null, object y = null, s
336376
{
337377
if( x == null && y == null)
338378
{
339-
throw new NotImplementedException("where");
379+
return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
380+
{
381+
name = scope;
382+
condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
383+
return gen_array_ops.where(condition: condition, name: name);
384+
});
340385
}
341386
else if(x != null && y != null)
342387
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,10 @@ public static Tensor[] unpack(Tensor value, int num, int axis = 0, string name =
274274
return _op.outputs;
275275
}
276276

277-
public static Tensor where()
277+
public static Tensor where(Tensor condition, string name = null)
278278
{
279-
throw new NotImplementedException("where");
279+
var _op = _op_def_lib._apply_op_helper("Where", name, new { input = condition });
280+
return _op.output;
280281
}
281282

282283
public static Tensor one_hot(Tensor indices, int depth,

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace Tensorflow
2323
{
2424
public static class dtypes
2525
{
26+
public static TF_DataType @bool = TF_DataType.TF_BOOL;
2627
public static TF_DataType int8 = TF_DataType.TF_INT8;
2728
public static TF_DataType int32 = TF_DataType.TF_INT32;
2829
public static TF_DataType int64 = TF_DataType.TF_INT64;

test/TensorFlowNET.UnitTest/TensorTest.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,5 +260,18 @@ public void batch_to_space_nd()
260260
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray<int>()));
261261
}
262262
}
263+
264+
[TestMethod]
265+
public void boolean_mask()
266+
{
267+
var tensor = new[] { 0, 1, 2, 3 };
268+
var mask = np.array(new[] { true, false, true, false });
269+
var masked = tf.boolean_mask(tensor, mask);
270+
using (var sess = tf.Session())
271+
{
272+
var result = sess.run(masked);
273+
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, result.ToArray<int>()));
274+
}
275+
}
263276
}
264277
}

0 commit comments

Comments
 (0)