Skip to content

Commit 5a414f0

Browse files
committed
add CumsumGrad, BroadcastToGrad
1 parent d986603 commit 5a414f0

7 files changed

Lines changed: 94 additions & 2 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, str
5353
public Tensor boolean_mask<T1, T2>(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0)
5454
=> array_ops.boolean_mask(tensor, mask, name: name, axis: axis);
5555

56+
/// <summary>
57+
/// Broadcast an array for a compatible shape.
58+
/// </summary>
59+
/// <param name="input"></param>
60+
/// <param name="shape"></param>
61+
/// <param name="name"></param>
62+
/// <returns></returns>
63+
public Tensor broadcast_to(Tensor input, TensorShape shape, string name = null)
64+
=> gen_array_ops.broadcast_to(input, shape, name: name);
65+
5666
public Tensor check_numerics(Tensor tensor, string message, string name = null)
5767
=> gen_array_ops.check_numerics(tensor, message, name: name);
5868

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,27 @@ namespace Tensorflow.Gradients
2727
[RegisterGradient("array_grad")]
2828
public class array_grad
2929
{
30+
[RegisterGradient("BroadcastTo")]
31+
public static Tensor[] _BroadcastToGrad(Operation op, Tensor[] grads)
32+
{
33+
var grad = grads[0];
34+
var input_value = op.inputs[0];
35+
var broadcast_shape = op.inputs[1];
36+
var input_value_shape = array_ops.shape(input_value);
37+
var (_, reduction_axes) = gen_array_ops.broadcast_gradient_args(broadcast_shape,
38+
input_value_shape);
39+
var updates_grad_reshaped = math_ops.reduce_sum(grad,
40+
axis: reduction_axes,
41+
keepdims: true);
42+
var updates_grad = array_ops.reshape(updates_grad_reshaped, input_value_shape);
43+
44+
return new Tensor[]
45+
{
46+
updates_grad,
47+
null
48+
};
49+
}
50+
3051
[RegisterGradient("ConcatV2")]
3152
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
3253
{

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
5858
return new Tensor[] { r1, r2 };
5959
}
6060

61+
[RegisterGradient("Cumsum")]
62+
public static Tensor[] _CumsumGrad(Operation op, Tensor[] grads)
63+
{
64+
var grad = grads[0];
65+
var axis = op.inputs[1];
66+
var exclusive = op.get_attr<bool>("exclusive");
67+
var reverse = op.get_attr<bool>("reverse");
68+
return new Tensor[]
69+
{
70+
math_ops.cumsum(grad, axis, exclusive: exclusive, reverse: !reverse),
71+
null
72+
};
73+
}
74+
6175
[RegisterGradient("DivNoNan")]
6276
public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
6377
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,5 +515,19 @@ public static Tensor broadcast_args(Tensor s0, Tensor s1, string name = null)
515515

516516
return _op.outputs[0];
517517
}
518+
519+
/// <summary>
520+
/// Broadcast an array for a compatible shape.
521+
/// </summary>
522+
/// <param name="input"></param>
523+
/// <param name="shape"></param>
524+
/// <param name="name"></param>
525+
/// <returns></returns>
526+
public static Tensor broadcast_to(Tensor input, int[] shape, string name = null)
527+
{
528+
var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name });
529+
530+
return _op.outputs[0];
531+
}
518532
}
519533
}

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ public static Tensor cosh(Tensor x, string name = null)
238238
return _op.outputs[0];
239239
}
240240

241-
public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
241+
public static Tensor cumsum<T>(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null)
242242
{
243243
var _op = _op_def_lib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });
244244

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s
8080
});
8181
}
8282

83-
public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
83+
public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
8484
{
8585
return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope =>
8686
{

test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,39 @@ namespace TensorFlowNET.UnitTest.gradients_test
1111
[TestClass]
1212
public class GradientsTest : PythonTest
1313
{
14+
[TestMethod]
15+
public void BroadcastToGrad()
16+
{
17+
var graph = tf.Graph().as_default();
18+
19+
var x = tf.constant(2, dtype: dtypes.float32);
20+
var y = tf.broadcast_to(x, (2, 4, 3));
21+
var grad = tf.gradients(y, x);
22+
23+
using (var sess = tf.Session(graph))
24+
{
25+
float result = sess.run(grad[0]);
26+
Assert.AreEqual(result, 24.0f);
27+
}
28+
}
29+
30+
[TestMethod]
31+
public void CumsumGrad()
32+
{
33+
var graph = tf.Graph().as_default();
34+
35+
var x = tf.constant(2, dtype: dtypes.float32);
36+
var y = tf.broadcast_to(x, (2, 4, 3));
37+
var z = tf.cumsum(y, axis: 1);
38+
var grad = tf.gradients(z, x);
39+
40+
using (var sess = tf.Session(graph))
41+
{
42+
float result = sess.run(grad[0]);
43+
Assert.AreEqual(result, 60.0f);
44+
}
45+
}
46+
1447
[Ignore("TODO")]
1548
[TestMethod]
1649
public void testGradients()

0 commit comments

Comments
 (0)