Skip to content

Commit fe4a06f

Browse files
committed
make_tensor_proto: supported additional types int[,] long[] long[,] float[,] double[] double[,] and byte[,]
1 parent 80a108b commit fe4a06f

2 files changed

Lines changed: 192 additions & 3 deletions

File tree

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,19 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
116116
case int intVal:
117117
nparray = intVal;
118118
break;
119+
case int[] intVals:
120+
nparray = np.array(intVals);
121+
break;
122+
case int[,] intVals:
123+
nparray = np.array(intVals);
124+
break;
119125
case long intVal:
120126
nparray = intVal;
121127
break;
122-
case int[] intVals:
128+
case long[] intVals:
129+
nparray = np.array(intVals);
130+
break;
131+
case long[,] intVals:
123132
nparray = np.array(intVals);
124133
break;
125134
case float floatVal:
@@ -128,9 +137,18 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
128137
case float[] floatVals:
129138
nparray = floatVals;
130139
break;
140+
case float[,] floatVals:
141+
nparray = np.array(floatVals);
142+
break;
131143
case double doubleVal:
132144
nparray = doubleVal;
133145
break;
146+
case double[] doubleVals:
147+
nparray = np.array(doubleVals);
148+
break;
149+
case double[,] doubleVals:
150+
nparray = np.array(doubleVals);
151+
break;
134152
case string strVal:
135153
nparray = strVal;
136154
break;
@@ -140,8 +158,11 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
140158
case byte[] byteValues:
141159
nparray = byteValues;
142160
break;
161+
case byte[,] byteValues:
162+
nparray = np.array(byteValues);
163+
break;
143164
default:
144-
throw new NotImplementedException("make_tensor_proto Not Implemented");
165+
throw new NotImplementedException($"make_tensor_proto: Support for type {values.GetType()} Not Implemented");
145166
}
146167
}
147168
else
@@ -174,7 +195,7 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
174195
nparray = Convert.ToString(values);
175196
break;
176197
default:
177-
throw new NotImplementedException("make_tensor_proto Not Implemented");
198+
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
178199
}
179200
}
180201
}
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Microsoft.VisualStudio.TestTools.UnitTesting;
5+
using Tensorflow;
6+
7+
namespace TensorFlowNET.UnitTest
8+
{
9+
/// <summary>
10+
/// excerpt of tensorflow/python/framework/ops_test.py
11+
/// # These cases test the private Graph._create_op_from_tf_operation
12+
/// # method. Arguably we should only test the public APIs that depend on this
13+
/// # method. However, this logic is complex and tricky, and it can be difficult to
14+
/// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
15+
/// # the control flow context isn't set properly, but a more complicated use case
16+
/// # that might not be obvious to test will fail). Thus we instead explicitly test
17+
/// # the low-level behavior.
18+
/// </summary>
19+
[TestClass]
20+
public class CreateOpFromTfOperationTest : PythonTest
21+
{
22+
23+
[TestMethod]
24+
public void TestShape()
25+
{
26+
var graph = tf.Graph().as_default();
27+
with<Graph>(graph, g =>
28+
{
29+
var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}});
30+
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
31+
var op = g._create_op_from_tf_operation(c_op);
32+
33+
Assert.AreEqual("myop", op.name);
34+
Assert.AreEqual("Identity", op.type);
35+
Assert.AreEqual(1, len(op.outputs));
36+
AssertItemsEqual(new []{2, 3}, op.outputs[0].shape);
37+
});
38+
}
39+
40+
/*def testUniqueName(self):
41+
g = ops.Graph()
42+
with g.as_default():
43+
c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
44+
c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
45+
op = g._create_op_from_tf_operation(c_op)
46+
op2 = g._create_op_from_tf_operation(c_op2)
47+
48+
# Create ops with same names as op1 and op2. We expect the new names to be
49+
# uniquified.
50+
op3 = test_ops.int_output(name="myop").op
51+
op4 = test_ops.int_output(name="myop_1").op
52+
53+
self.assertEqual(op.name, "myop")
54+
self.assertEqual(op2.name, "myop_1")
55+
self.assertEqual(op3.name, "myop_2")
56+
self.assertEqual(op4.name, "myop_1_1")
57+
58+
@test_util.run_v1_only("b/120545219")
59+
def testCond(self):
60+
g = ops.Graph()
61+
with g.as_default():
62+
x = test_ops.int_output()
63+
64+
def true_fn():
65+
ops._create_c_op(ops.get_default_graph(),
66+
ops._NodeDef("IntInput", "cond/myop"), [x], [])
67+
new_ops = g._add_new_tf_operations()
68+
self.assertEqual(len(new_ops), 1)
69+
return x
70+
71+
control_flow_ops.cond(x < 10, true_fn, lambda: x)
72+
73+
op = g.get_operation_by_name("cond/myop")
74+
self.assertIsNotNone(op)
75+
self.assertEqual(op.name, "cond/myop")
76+
self.assertEqual(op.type, "IntInput")
77+
self.assertEqual(op.outputs, [])
78+
op_input = op.inputs[0].op
79+
self.assertEqual(op_input.type, "Switch")
80+
self.assertEqual(op_input.inputs[0], x)
81+
self.assertEqual(op.graph, g)
82+
# pylint: disable=protected-access
83+
self.assertIsNotNone(op._get_control_flow_context())
84+
self.assertEqual(op._get_control_flow_context().name,
85+
"cond/cond_text")
86+
# pylint: enable=protected-access
87+
88+
@test_util.run_v1_only("b/120545219")
89+
def testWhileLoop(self):
90+
g = ops.Graph()
91+
with g.as_default():
92+
x = test_ops.int_output()
93+
94+
def body(i):
95+
ops._create_c_op(ops.get_default_graph(),
96+
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
97+
new_ops = g._add_new_tf_operations()
98+
self.assertEqual(len(new_ops), 1)
99+
return i
100+
101+
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
102+
103+
op = g.get_operation_by_name("myloop/myop")
104+
self.assertIsNotNone(op)
105+
self.assertEqual(op.name, "myloop/myop")
106+
self.assertEqual(op.type, "IntInput")
107+
self.assertEqual(op.outputs, [])
108+
op_input = op.inputs[0].op
109+
self.assertEqual(op_input.type, "Enter")
110+
self.assertEqual(list(op_input.inputs), [x])
111+
self.assertEqual(op.graph, g)
112+
# pylint: disable=protected-access
113+
self.assertIsNotNone(op._get_control_flow_context())
114+
self.assertEqual(op._get_control_flow_context().name,
115+
"myloop/while_context")
116+
# pylint: enable=protected-access
117+
118+
@test_util.run_v1_only("b/120545219")
119+
def testWhileLoopWithInternalControlDep(self):
120+
g = ops.Graph()
121+
with g.as_default():
122+
x = test_ops.int_output()
123+
124+
def body(i):
125+
c = constant_op.constant(1.0, name="c")
126+
ops._create_c_op(ops.get_default_graph(),
127+
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
128+
with ops.control_dependencies([c]):
129+
new_ops = g._add_new_tf_operations()
130+
self.assertEqual(len(new_ops), 1)
131+
return i
132+
133+
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
134+
135+
op = g.get_operation_by_name("myloop/myop")
136+
self.assertIsNotNone(op)
137+
c = g.get_operation_by_name("myloop/c")
138+
self.assertIsNotNone(c)
139+
# Internal control dep is preserved
140+
self.assertEqual(op.control_inputs, [c])
141+
142+
@test_util.run_v1_only("b/120545219")
143+
def testWhileLoopWithExternalControlDep(self):
144+
g = ops.Graph()
145+
with g.as_default():
146+
x = test_ops.int_output()
147+
c = constant_op.constant(1.0)
148+
149+
def body(i):
150+
ops._create_c_op(ops.get_default_graph(),
151+
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
152+
with ops.control_dependencies([c]):
153+
new_ops = g._add_new_tf_operations()
154+
self.assertEqual(len(new_ops), 1)
155+
return i
156+
157+
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
158+
159+
op = g.get_operation_by_name("myloop/myop")
160+
self.assertIsNotNone(op)
161+
# External control dep is removed and replaced with internal control dep
162+
self.assertNotEqual(op.control_inputs[0], c.op)
163+
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
164+
165+
166+
*/
167+
}
168+
}

0 commit comments

Comments
 (0)