Skip to content

Commit f55650d

Browse files
committed
Fix gradient of squared_difference SciSharp#787
1 parent 8d0bd50 commit f55650d

4 files changed

Lines changed: 8 additions & 5 deletions

File tree

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,14 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
122122
[RegisterGradient("SquaredDifference")]
123123
public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
124124
{
125-
//"""Returns the gradient for (x-y)^2."""
126125
Tensor x = op.inputs[0];
127126
Tensor y = op.inputs[1];
127+
var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
128+
var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
128129
return new Tensor[]
129130
{
130-
x,
131-
y
131+
x_grad,
132+
-x_grad
132133
};
133134
}
134135
/// <summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null)
272272
public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null)
273273
=> gen_math_ops.mul_no_nan(x, y, name: name);
274274

275+
public static Tensor scalar_mul<Tscale, Tx>(Tscale scale, Tx x, string name = null)
276+
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(scale, x));
277+
275278
public static Tensor real(Tensor input, string name = null)
276279
{
277280
return tf_with(ops.name_scope(name, "Real", new[] { input }), scope =>

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public partial class tensorflow : ITensorFlowObject
4848
public tensorflow()
4949
{
5050
Logger = new LoggerConfiguration()
51-
.MinimumLevel.Error()
51+
.MinimumLevel.Debug()
5252
.WriteTo.Console()
5353
.CreateLogger();
5454

test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ public void ConstantSquare()
2525
Assert.AreEqual((float)grad, 3.0f);
2626
}
2727

28-
[Ignore]
2928
[TestMethod]
3029
public void SquaredDifference_Constant()
3130
{

0 commit comments

Comments
 (0)