Skip to content

Commit 4d8ae9a

Browse files
committed
Added unit-tests for autocasting mechanism.
1 parent 00830b8 commit 4d8ae9a

1 file changed

Lines changed: 57 additions & 0 deletions

File tree

test/TensorFlowNET.UnitTest/SessionTest.cs

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Text;
88
using FluentAssertions;
99
using Google.Protobuf;
10+
using NumSharp.Backends;
1011
using Tensorflow;
1112
using Tensorflow.Util;
1213
using static Tensorflow.Binding;
@@ -131,5 +132,61 @@ public void Eval_LargeString_Scalar()
131132
}
132133
}
133134
}
135+
136+
[TestMethod]
137+
public void Autocast_Case1()
138+
{
139+
var sess = tf.Session().as_default();
140+
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
141+
var op = tf.reshape(input, new int[] {2, 3});
142+
sess.run(tf.global_variables_initializer());
143+
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));
144+
145+
ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
146+
print(ret.dtype);
147+
print(ret);
148+
}
149+
150+
[TestMethod]
151+
public void Autocast_Case2()
152+
{
153+
var sess = tf.Session().as_default();
154+
var input = tf.placeholder(tf.float64, shape: new TensorShape(6));
155+
var op = tf.reshape(input, new int[] {2, 3});
156+
sess.run(tf.global_variables_initializer());
157+
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));
158+
159+
ret.Should().BeOfType<double>().And.BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1);
160+
print(ret.dtype);
161+
print(ret);
162+
}
163+
164+
[TestMethod]
165+
public void Autocast_Case3()
166+
{
167+
var sess = tf.Session().as_default();
168+
var input = tf.placeholder(tf.int16, shape: new TensorShape(6));
169+
var op = tf.reshape(input, new int[] {2, 3});
170+
sess.run(tf.global_variables_initializer());
171+
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));
172+
173+
ret.Should().BeOfType<short>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
174+
print(ret.dtype);
175+
print(ret);
176+
}
177+
178+
[TestMethod]
179+
public void Autocast_Case4()
180+
{
181+
var sess = tf.Session().as_default();
182+
var input = tf.placeholder(tf.@byte, shape: new TensorShape(6));
183+
var op = tf.reshape(input, new int[] {2, 3});
184+
sess.run(tf.global_variables_initializer());
185+
var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(NPTypeCode.Single) + 0.1f));
186+
187+
ret.Should().BeOfType<byte>().And.BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
188+
print(ret.dtype);
189+
print(ret);
190+
}
134191
}
135192
}

0 commit comments

Comments
 (0)