@@ -11,21 +11,19 @@ namespace TensorFlowNET.UnitTest
1111 [ TestClass ]
1212 public class ConstantTest
1313 {
14- Tensor tensor ;
15-
1614 [ TestMethod ]
1715 public void ScalarConst ( )
1816 {
19- tensor = tf . constant ( 8 ) ; // int
20- tensor = tf . constant ( 6.0f ) ; // float
21- tensor = tf . constant ( 6.0 ) ; // double
17+ var tensor1 = tf . constant ( 8 ) ; // int
18+ var tensor2 = tf . constant ( 6.0f ) ; // float
19+ var tensor3 = tf . constant ( 6.0 ) ; // double
2220 }
2321
2422 [ TestMethod ]
2523 public void StringConst ( )
2624 {
2725 string str = "Hello, TensorFlow.NET!" ;
28- tensor = tf . constant ( str ) ;
26+ var tensor = tf . constant ( str ) ;
2927 Python . with < Session > ( tf . Session ( ) , sess =>
3028 {
3129 var result = sess . run ( tensor ) ;
@@ -37,7 +35,7 @@ public void StringConst()
3735 public void ZerosConst ( )
3836 {
3937 // small size
40- tensor = tf . zeros ( new Shape ( 3 , 2 ) , TF_DataType . TF_INT32 , "small" ) ;
38+ var tensor = tf . zeros ( new Shape ( 3 , 2 ) , TF_DataType . TF_INT32 , "small" ) ;
4139 Python . with < Session > ( tf . Session ( ) , sess =>
4240 {
4341 var result = sess . run ( tensor ) ;
@@ -67,11 +65,34 @@ public void NDimConst()
6765 {
6866 var nd = np . array ( new int [ ] [ ]
6967 {
70- new int [ ] { 1 , 2 , 3 } ,
71- new int [ ] { 4 , 5 , 6 }
68+ new int [ ] { 3 , 1 , 1 } ,
69+ new int [ ] { 2 , 1 , 3 }
7270 } ) ;
7371
74- tensor = tf . constant ( nd ) ;
72+ var tensor = tf . constant ( nd ) ;
73+ Python . with < Session > ( tf . Session ( ) , sess =>
74+ {
75+ var result = sess . run ( tensor ) ;
76+ var data = result . Data < int > ( ) ;
77+
78+ Assert . AreEqual ( result . shape [ 0 ] , 2 ) ;
79+ Assert . AreEqual ( result . shape [ 1 ] , 3 ) ;
80+ Assert . IsTrue ( Enumerable . SequenceEqual ( new int [ ] { 3 , 1 , 2 , 1 , 1 , 3 } , data ) ) ;
81+ } ) ;
82+ }
83+
84+ [ TestMethod ]
85+ public void Multiply ( )
86+ {
87+ var a = tf . constant ( 3.0 ) ;
88+ var b = tf . constant ( 2.0 ) ;
89+ var c = a * b ;
90+
91+ var sess = tf . Session ( ) ;
92+ double result = sess . run ( c ) ;
93+ sess . close ( ) ;
94+
95+ Assert . AreEqual ( 6.0 , result ) ;
7596 }
7697 }
7798}
0 commit comments