Skip to content

Commit 242e051

Browse files
committed
add function test.
1 parent 66ba622 commit 242e051

8 files changed

Lines changed: 594 additions & 13 deletions

File tree

test/TensorFlowNET.UnitTest/EagerModeTestBase.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
using System;
33
using System.Collections.Generic;
44
using System.Text;
5-
using TensorFlowNET.UnitTest;
65
using static Tensorflow.Binding;
76

8-
namespace Tensorflow.UnitTest
7+
namespace TensorFlowNET.UnitTest
98
{
109
public class EagerModeTestBase : PythonTest
1110
{

test/TensorFlowNET.UnitTest/ImageTest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics
1414
/// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file
1515
/// </summary>
1616
[TestClass]
17-
public class ImageTest
17+
public class ImageTest : GraphModeTestBase
1818
{
1919
string imgPath = "shasta-daisy.jpg";
2020
Tensor contents;
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq.Expressions;
5+
using System.Runtime.CompilerServices;
6+
using System.Security.Cryptography.X509Certificates;
7+
using System.Text;
8+
using Tensorflow;
9+
using static Tensorflow.Binding;
10+
11+
namespace TensorFlowNET.UnitTest.ManagedAPI
12+
{
13+
[TestClass]
14+
public class ControlFlowApiTest
15+
{
16+
[TestMethod]
17+
public void WhileLoopOneInputEagerMode()
18+
{
19+
tf.enable_eager_execution();
20+
21+
var i = tf.constant(2);
22+
Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
23+
Func<Tensor, Tensor> b = (x) => tf.add(x, 1);
24+
var r = tf.while_loop(c, b, i);
25+
Assert.AreEqual(10, (int)r);
26+
}
27+
28+
[TestMethod]
29+
public void WhileLoopTwoInputsEagerMode()
30+
{
31+
tf.enable_eager_execution();
32+
33+
var i = tf.constant(2);
34+
var j = tf.constant(3);
35+
Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10);
36+
Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
37+
var r = tf.while_loop(c, b, new[] { i, j });
38+
Assert.AreEqual(5, (int)r[0]);
39+
Assert.AreEqual(6, (int)r[1]);
40+
}
41+
42+
[TestMethod, Ignore]
43+
public void WhileLoopGraphMode()
44+
{
45+
tf.compat.v1.disable_eager_execution();
46+
47+
var i = tf.constant(2);
48+
Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
49+
Func<Tensor, Tensor> b = (x) => tf.add(x, 1);
50+
var r = tf.while_loop(c, b, i);
51+
Assert.AreEqual(10, (int)r);
52+
}
53+
}
54+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using Tensorflow;
7+
using Tensorflow.Graphs;
8+
using static Tensorflow.Binding;
9+
10+
namespace TensorFlowNET.UnitTest.ManagedAPI
11+
{
12+
[TestClass]
13+
public class FunctionApiTest : TFNetApiTest
14+
{
15+
[TestMethod]
16+
public void TwoInputs_OneOutput()
17+
{
18+
var func = tf.autograph.to_graph(Add);
19+
var a = tf.constant(1);
20+
var b = tf.constant(2);
21+
var output = func(a, b);
22+
Assert.AreEqual(3, (int)output);
23+
}
24+
25+
Tensor Add(Tensor a, Tensor b)
26+
{
27+
return a + b;
28+
}
29+
30+
[TestMethod]
31+
public void TwoInputs_OneOutput_Condition()
32+
{
33+
var func = tf.autograph.to_graph(Condition);
34+
var a = tf.constant(3);
35+
var b = tf.constant(2);
36+
var output = func(a, b);
37+
Assert.AreEqual(2, (int)output);
38+
}
39+
40+
Tensor Condition(Tensor a, Tensor b)
41+
{
42+
return tf.cond(a < b, a, b);
43+
}
44+
45+
[TestMethod]
46+
public void TwoInputs_OneOutput_Lambda()
47+
{
48+
var func = tf.autograph.to_graph((x, y) => x * y);
49+
var output = func(tf.constant(3), tf.constant(2));
50+
Assert.AreEqual(6, (int)output);
51+
}
52+
53+
[TestMethod]
54+
public void TwoInputs_OneOutput_WhileLoop()
55+
{
56+
var func = tf.autograph.to_graph((x, y) => x * y);
57+
var output = func(tf.constant(3), tf.constant(2));
58+
Assert.AreEqual(6, (int)output);
59+
}
60+
61+
Tensor WhileLoop()
62+
{
63+
var i = tf.constant(0);
64+
Func<Tensor, Tensor> c = i => tf.less(i, 10);
65+
Func<Tensor, Tensor> b = i => tf.add(i, 1);
66+
//var r = tf.(c, b, [i])
67+
throw new NotImplementedException("");
68+
}
69+
}
70+
}

0 commit comments

Comments
 (0)