forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathGradientTest.cs
More file actions
105 lines (89 loc) · 3.14 KB
/
GradientTest.cs
File metadata and controls
105 lines (89 loc) · 3.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Graph-mode tests for <c>tf.gradients</c>: auto-generated op names, numeric
    /// gradient values, and gradient construction through <c>tf.strided_slice</c>.
    /// </summary>
    /// <remarks>
    /// MSTest's <c>Assert.AreEqual</c> takes <c>(expected, actual)</c>; that order is used
    /// throughout so failure messages report the literal as "Expected".
    /// </remarks>
    [TestClass]
    public class GradientTest
    {
        /// <summary>
        /// Builds <c>b = 2·a</c> and <c>ys = a + b</c>, checks the auto-generated op/tensor
        /// names, then requests gradients of <c>ys</c> w.r.t. <c>a</c> and <c>b</c> with both
        /// listed in <c>stop_gradients</c>. The test expects both results to be the same
        /// <c>gradients/Fill:0</c> tensor.
        /// </summary>
        [TestMethod]
        public void Gradients()
        {
            // Install a fresh default graph so auto-generated names start at
            // "Const"/"mul"/"add"; every name assertion below depends on that.
            tf.Graph().as_default();

            var a = tf.constant(0.0);
            var b = 2.0 * a;
            Assert.AreEqual("mul:0", b.name);
            Assert.AreEqual("mul/x:0", b.op.inputs[0].name);
            Assert.AreEqual("Const:0", b.op.inputs[1].name);

            var ys = a + b;
            Assert.AreEqual("add:0", ys.name);
            Assert.AreEqual("Const:0", ys.op.inputs[0].name);
            Assert.AreEqual("mul:0", ys.op.inputs[1].name);

            // With both xs in stop_gradients nothing is back-propagated through "mul",
            // so both gradients are expected to be the same Fill tensor
            // (presumably the ones-seed gradient of ys — TODO confirm).
            var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });
            Assert.AreEqual("gradients/Fill:0", g[0].name);
            Assert.AreEqual("gradients/Fill:0", g[1].name);
        }

        /// <summary>
        /// For <c>y = x·x·0.1</c>, checks that <c>dy/dx = 0.2·x</c> evaluates to 1.4 at
        /// <c>x = 7</c> and that the gradient tensor is named <c>gradients/AddN:0</c>.
        /// </summary>
        [TestMethod]
        public void Gradient2x()
        {
            var graph = tf.Graph().as_default();
            using (var sess = tf.Session(graph))
            {
                var x = tf.constant(7.0f);
                var y = x * x * tf.constant(0.1f);

                var grad = tf.gradients(y, x);
                // x feeds the product twice; the expected name reflects the partial
                // gradients being summed by an AddN op.
                Assert.AreEqual("gradients/AddN:0", grad[0].name);

                float r = sess.run(grad[0]);
                // float32 arithmetic: compare against the exact math value with a tolerance.
                Assert.AreEqual(1.4f, r, 1e-5f);
            }
        }

        /// <summary>
        /// For <c>y = x·x·x·0.1</c>, checks that <c>dy/dx = 0.3·x²</c> evaluates to 14.7 at
        /// <c>x = 7</c> and that the gradient tensor is named <c>gradients/AddN:0</c>.
        /// </summary>
        [TestMethod]
        public void Gradient3x()
        {
            var graph = tf.Graph().as_default();
            tf_with(tf.Session(graph), sess => {
                var x = tf.constant(7.0f);
                var y = x * x * x * tf.constant(0.1f);

                var grad = tf.gradients(y, x);
                Assert.AreEqual("gradients/AddN:0", grad[0].name);

                float r = sess.run(grad[0]);
                // float32 evaluation yields ~14.700001; a tolerance avoids pinning the
                // exact rounding artifact while still checking 0.3 * 7^2 = 14.7.
                Assert.AreEqual(14.7f, r, 1e-4f);
            });
        }

        /// <summary>
        /// Slices a 3x2x3 int tensor with begin (0,0,0), end (3,2,3) and stride 2 on every
        /// axis, which selects indices {0,2} x {0} x {0,2}. Checks the sliced values and
        /// that a gradient can be built through the slice.
        /// </summary>
        [TestMethod]
        public void StridedSlice()
        {
            var graph = tf.Graph().as_default();
            var t = tf.constant(np.array(new int[,,]
            {
                {
                    { 11, 12, 13 },
                    { 21, 22, 23 }
                },
                {
                    { 31, 32, 33 },
                    { 41, 42, 43 }
                },
                {
                    { 51, 52, 53 },
                    { 61, 62, 63 }
                }
            }));

            var slice = tf.strided_slice(t,
                begin: new[] { 0, 0, 0 },
                end: new[] { 3, 2, 3 },
                strides: new[] { 2, 2, 2 });

            var y = slice + slice;
            var g = tf.gradients(y, new Tensor[] { slice, slice });
            // Previously computed but never checked: expect one gradient per xs entry.
            Assert.AreEqual(2, g.Length);

            using (var sess = tf.Session(graph))
            {
                var r = sess.run(slice);
                Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
                Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }));
                Assert.IsTrue(Enumerable.SequenceEqual(r[1].GetData<int>(), new[] { 51, 53 }));
            }
        }
    }
}