Skip to content

Commit bf9c1ad

Browse files
MPnoyOceania2018
authored andcommitted
GradientConcatTest
1 parent a653914 commit bf9c1ad

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

test/TensorFlowNET.UnitTest/ManagedAPI/GradientTest.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using NumSharp;
23
using Tensorflow;
34
using static Tensorflow.Binding;
45

@@ -65,15 +66,16 @@ public void GradientSliceTest()
6566
[TestMethod]
6667
public void GradientConcatTest()
6768
{
68-
var X = tf.zeros(10);
69-
var W = tf.Variable(-0.06f, name: "weight");
70-
var b = tf.Variable(-0.73f, name: "bias");
71-
var test = tf.concat(new Tensor[] { W, b }, 0);
69+
var w1 = tf.Variable(new[] { new[] { 1f } });
70+
var w2 = tf.Variable(new[] { new[] { 3f } });
7271
using var g = tf.GradientTape();
73-
var pred = test[0] * X + test[1];
74-
var gradients = g.gradient(pred, (W, b));
75-
Assert.IsNull(gradients.Item1);
76-
Assert.IsNull(gradients.Item2);
72+
var w = tf.concat(new Tensor[] { w1, w2 }, 0);
73+
var x = tf.ones((1, 2));
74+
var y = tf.reduce_sum(x, 1);
75+
var r = tf.matmul(w, x);
76+
var gradients = g.gradient(r, w);
77+
Assert.AreEqual((float)gradients[0][0], 2f);
78+
Assert.AreEqual((float)gradients[1][0], 2f);
7779
}
7880
}
7981
}

0 commit comments

Comments
 (0)