Skip to content

Commit 468cb8e

Browse files
committed
add TF_GraphSetOutputHandleShapesAndTypes api.
1 parent 9bb603d commit 468cb8e

3 files changed

Lines changed: 22 additions & 6 deletions

File tree

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ public partial class Graph : DisposableObject
105105

106106
public bool building_function;
107107

108+
string _container = "";
109+
public string Container => _container;
110+
108111
int _seed;
109112
public int seed
110113
{
@@ -151,6 +154,8 @@ private Tensor _as_graph_element(object obj)
151154
{
152155
if (obj is RefVariable var)
153156
return var._as_graph_element();
157+
else if (obj is ResourceVariable resVar)
158+
return resVar.GraphElement;
154159

155160
return null;
156161
}

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,21 @@ public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandl
297297
[DllImport(TensorFlowLibName)]
298298
public static extern SafeImportGraphDefOptionsHandle TF_NewImportGraphDefOptions();
299299

300+
/// <summary>
301+
/// Set the shapes and types of the output's handle.
302+
/// </summary>
303+
/// <param name="graph">TF_Graph*</param>
304+
/// <param name="output">TF_Output</param>
305+
/// <param name="num_shapes_and_types">int</param>
306+
/// <param name="shapes">const int64_t**</param>
307+
/// <param name="ranks">const int*</param>
308+
/// <param name="types">const TF_DataType*</param>
309+
/// <param name="status">TF_Status*</param>
310+
[DllImport(TensorFlowLibName)]
311+
public static extern void TF_GraphSetOutputHandleShapesAndTypes(IntPtr graph, TF_Output output,
312+
int num_shapes_and_types, IntPtr[] shapes, int[] ranks, DataType[] types,
313+
SafeStatusHandle status);
314+
300315
/// <summary>
301316
/// Updates 'dst' to consume 'new_src'.
302317
/// </summary>

test/TensorFlowNET.UnitTest/MultithreadingTests.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,9 @@ void Core(int tid)
145145
//tf.Session created an other graph
146146
using (var sess = tf.Session())
147147
{
148-
Tensor t = null;
149148
for (int i = 0; i < 100; i++)
150149
{
151-
t = new Tensor(new int[] {1, 2, 3});
150+
var t = new Tensor(new int[] {1, 2, 3});
152151
}
153152
}
154153
}
@@ -167,12 +166,9 @@ unsafe void Core(int tid)
167166
{
168167
using (var sess = tf.Session())
169168
{
170-
#pragma warning disable CS0219 // Variable is assigned but its value is never used
171-
Tensor t = null;
172-
#pragma warning restore CS0219 // Variable is assigned but its value is never used
173169
for (int i = 0; i < 100; i++)
174170
{
175-
var v = (int*) Marshal.AllocHGlobal(sizeof(int));
171+
var v = (int*)Marshal.AllocHGlobal(sizeof(int));
176172
c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs();
177173
var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0,
178174
data: (IntPtr) v, len: (UIntPtr) sizeof(int),

0 commit comments

Comments
 (0)