Skip to content

Commit 36272d6

Browse files
committed
TEST(CAPI, AllocateTensor)
1 parent 1f6ea31 commit 36272d6

13 files changed

Lines changed: 205 additions & 51 deletions

File tree

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso
44
[![Join the chat at https://gitter.im/publiclab/publiclab](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/sci-sharp/community)
55
![Tensorflow.NET](https://ci.appveyor.com/api/projects/status/tensorflow-net-p7kmsjyo10ey?svg=true)
66

7-
TensorFlow.NET is a member project of SciSharp stack.
7+
TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) stack.
88

99
![tensors_flowing](docs/assets/tensors_flowing.gif)
1010

@@ -45,3 +45,5 @@ using(var sess = tf.Session())
4545
var o = sess.run(c, feed_dict);
4646
}
4747
```
48+
49+
Star me or raise issue on [Github](https://github.com/SciSharp/TensorFlow.NET) feel free.

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Tensorflow
1313
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
1414
/// https://www.tensorflow.org/guide/graphs
1515
/// </summary>
16-
public class Graph
16+
public class Graph : IDisposable
1717
{
1818
private IntPtr _handle;
1919
private Dictionary<int, Operation> _nodes_by_id;
@@ -25,6 +25,11 @@ public class Graph
2525

2626
private string _name_stack;
2727

28+
public Graph()
29+
{
30+
_handle = c_api.TF_NewGraph();
31+
}
32+
2833
public Graph(IntPtr graph)
2934
{
3035
_handle = graph;
@@ -171,6 +176,11 @@ public Operation[] get_operations()
171176
return _nodes_by_name.Values.Select(x => x).ToArray();
172177
}
173178

179+
public void Dispose()
180+
{
181+
c_api.TF_DeleteGraph(_handle);
182+
}
183+
174184
public static implicit operator IntPtr(Graph graph)
175185
{
176186
return graph._handle;

src/TensorFlowNET.Core/Graphs/c_api.graph.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,14 @@ namespace Tensorflow
77
{
88
public static partial class c_api
99
{
10+
/// <summary>
11+
/// Destroy an options object. Graph will be deleted once no more
12+
/// TFSession's are referencing it.
13+
/// </summary>
14+
/// <param name="graph"></param>
15+
[DllImport(TensorFlowLibName)]
16+
public static extern void TF_DeleteGraph(IntPtr graph);
17+
1018
[DllImport(TensorFlowLibName)]
1119
public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status);
1220

@@ -21,14 +29,14 @@ public static partial class c_api
2129
/// <param name="num_dims"></param>
2230
/// <param name="status"></param>
2331
[DllImport(TensorFlowLibName)]
24-
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);
32+
public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);
2533

2634
/// <summary>
2735
/// Sets the shape of the Tensor referenced by `output` in `graph` to
2836
/// the shape described by `dims` and `num_dims`.
2937
/// </summary>
3038
[DllImport(TensorFlowLibName)]
31-
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status);
39+
public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status);
3240

3341
/// <summary>
3442
/// Returns the number of dimensions of the Tensor referenced by `output`

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,38 @@ public class Operation
1414

1515
private Status status = new Status();
1616

17-
public string name => c_api.TF_OperationName(_handle);
18-
public string optype => c_api.TF_OperationOpType(_handle);
19-
public string device => c_api.TF_OperationDevice(_handle);
20-
public int NumOutputs => c_api.TF_OperationNumOutputs(_handle);
21-
public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
22-
public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status);
23-
public int NumInputs => c_api.TF_OperationNumInputs(_handle);
24-
public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
25-
public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle);
26-
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
17+
public string name { get; }
18+
public string optype { get; }
19+
public string device { get; }
20+
public int NumOutputs { get; }
21+
public TF_DataType OutputType { get; }
22+
public int OutputListLength { get; }
23+
public int NumInputs { get; }
24+
public int NumConsumers { get; }
25+
public int NumControlInputs { get; }
26+
public int NumControlOutputs { get; }
2727

2828
private Tensor[] _outputs;
2929
public Tensor[] outputs => _outputs;
3030
public Tensor[] inputs;
3131

3232
public Operation(IntPtr handle)
3333
{
34+
if (handle == IntPtr.Zero)
35+
return;
36+
3437
_handle = handle;
38+
39+
name = c_api.TF_OperationName(_handle);
40+
optype = c_api.TF_OperationOpType(_handle);
41+
device = "";// c_api.TF_OperationDevice(_handle);
42+
NumOutputs = c_api.TF_OperationNumOutputs(_handle);
43+
OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0));
44+
OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status);
45+
NumInputs = c_api.TF_OperationNumInputs(_handle);
46+
NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0));
47+
NumControlInputs = c_api.TF_OperationNumControlInputs(_handle);
48+
NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle);
3549
}
3650

3751
public Operation(Graph g, string opType, string oper_name)

src/TensorFlowNET.Core/Operations/TF_Output.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public TF_Output(IntPtr oper, int index)
1414
this.index = index;
1515
}
1616

17-
public IntPtr oper;
17+
public unsafe IntPtr oper;
1818
public int index;
1919
}
2020
}

src/TensorFlowNET.Core/Operations/c_api.ops.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ public static partial class c_api
2222
[DllImport(TensorFlowLibName)]
2323
public static extern void TF_AddInput(IntPtr desc, TF_Output input);
2424

25+
/// <summary>
26+
/// For inputs that take a list of tensors.
27+
/// inputs must point to TF_Output[num_inputs].
28+
/// </summary>
29+
/// <param name="desc"></param>
30+
/// <param name="inputs"></param>
31+
[DllImport(TensorFlowLibName)]
32+
public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs);
33+
2534
[DllImport(TensorFlowLibName)]
2635
public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status);
2736

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Tensorflow
1111
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
1212
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
1313
/// </summary>
14-
public class Tensor
14+
public class Tensor : IDisposable
1515
{
1616
private readonly IntPtr _handle;
1717

@@ -38,6 +38,7 @@ public class Tensor
3838
/// n n-Tensor (you get the idea)
3939
/// </summary>
4040
public int rank;
41+
public int NDims => rank;
4142

4243
/// <summary>
4344
/// if original buffer is free.
@@ -96,7 +97,7 @@ private IntPtr Allocate(NDArray nd)
9697
nd.shape.Select(x => (long)x).ToArray(), // shape
9798
nd.ndim,
9899
dotHandle,
99-
(UIntPtr)(nd.size * nd.dtypesize),
100+
(ulong)(nd.size * nd.dtypesize),
100101
(IntPtr values, IntPtr len, ref bool closure) =>
101102
{
102103
// Free the original buffer and set flag
@@ -160,9 +161,19 @@ public TF_DataType ToTFDataType(Type type)
160161
return TF_DataType.DtInvalid;
161162
}
162163

164+
public void Dispose()
165+
{
166+
c_api.TF_DeleteTensor(_handle);
167+
}
168+
163169
public static implicit operator IntPtr(Tensor tensor)
164170
{
165171
return tensor._handle;
166172
}
173+
174+
public static implicit operator Tensor(IntPtr handle)
175+
{
176+
return new Tensor(handle);
177+
}
167178
}
168179
}

src/TensorFlowNET.Core/Tensors/c_api.tensor.cs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,23 @@ namespace Tensorflow
77
{
88
public static partial class c_api
99
{
10+
[DllImport(TensorFlowLibName)]
11+
public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len);
12+
1013
/// <summary>
1114
/// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value.
1215
/// </summary>
1316
/// <param name="dt"></param>
1417
/// <returns></returns>
1518
[DllImport(TensorFlowLibName)]
16-
public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt);
19+
public static extern ulong TF_DataTypeSize(TF_DataType dt);
1720

1821
/// <summary>
1922
/// Destroy a tensor.
2023
/// </summary>
2124
/// <param name="tensor"></param>
2225
[DllImport(TensorFlowLibName)]
23-
public static unsafe extern void TF_DeleteTensor(IntPtr tensor);
26+
public static extern void TF_DeleteTensor(IntPtr tensor);
2427

2528
/// <summary>
2629
/// Return the length of the tensor in the "dim_index" dimension.
@@ -30,7 +33,7 @@ public static partial class c_api
3033
/// <param name="dim_index"></param>
3134
/// <returns></returns>
3235
[DllImport(TensorFlowLibName)]
33-
public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index);
36+
public static extern long TF_Dim(IntPtr tensor, int dim_index);
3437

3538
/// <summary>
3639
/// Return a new tensor that holds the bytes data[0,len-1]
@@ -44,38 +47,38 @@ public static partial class c_api
4447
/// <param name="deallocator_arg"></param>
4548
/// <returns></returns>
4649
[DllImport(TensorFlowLibName)]
47-
public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg);
50+
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg);
4851

4952
/// <summary>
5053
/// Return the number of dimensions that the tensor has.
5154
/// </summary>
5255
/// <param name="tensor"></param>
5356
/// <returns></returns>
5457
[DllImport(TensorFlowLibName)]
55-
public static extern unsafe int TF_NumDims(IntPtr tensor);
58+
public static extern int TF_NumDims(IntPtr tensor);
5659

5760
/// <summary>
5861
/// Return the size of the underlying data in bytes.
5962
/// </summary>
6063
/// <param name="tensor"></param>
6164
/// <returns></returns>
6265
[DllImport(TensorFlowLibName)]
63-
public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor);
66+
public static extern ulong TF_TensorByteSize(IntPtr tensor);
6467

6568
/// <summary>
6669
/// Return a pointer to the underlying data buffer.
6770
/// </summary>
6871
/// <param name="tensor"></param>
6972
/// <returns></returns>
7073
[DllImport(TensorFlowLibName)]
71-
public static extern unsafe IntPtr TF_TensorData(IntPtr tensor);
74+
public static extern IntPtr TF_TensorData(IntPtr tensor);
7275

7376
/// <summary>
7477
/// Return the type of a tensor element.
7578
/// </summary>
7679
/// <param name="tensor"></param>
7780
/// <returns></returns>
7881
[DllImport(TensorFlowLibName)]
79-
public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor);
82+
public static extern TF_DataType TF_TensorType(IntPtr tensor);
8083
}
8184
}

src/TensorFlowNET.Core/c_api.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace Tensorflow
77
{
88
/// <summary>
99
/// C API for TensorFlow.
10+
/// Port from tensorflow\c\c_api.h
1011
///
1112
/// The API leans towards simplicity and uniformity instead of convenience
1213
/// since most usage will be by language specific wrappers.

test/TensorFlowNET.UnitTest/GraphTest.cs

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,21 @@ namespace TensorFlowNET.UnitTest
99
[TestClass]
1010
public class GraphTest
1111
{
12+
/// <summary>
13+
/// Port from c_api_test.cc
14+
/// `TEST(CAPI, Graph)`
15+
/// </summary>
1216
[TestMethod]
13-
public void Graph()
17+
public void c_api_Graph()
1418
{
1519
var s = new Status();
16-
var graph = tf.get_default_graph();
20+
var graph = new Graph();
1721

1822
// Make a placeholder operation.
1923
var feed = c_test_util.Placeholder(graph, s);
2024
Assert.AreEqual("feed", feed.name);
2125
Assert.AreEqual("Placeholder", feed.optype);
22-
//Assert.AreEqual("", feed.device);
26+
Assert.AreEqual("", feed.device);
2327
Assert.AreEqual(1, feed.NumOutputs);
2428
Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType);
2529
Assert.AreEqual(1, feed.OutputListLength);
@@ -30,6 +34,19 @@ public void Graph()
3034

3135
AttrValue attr_value = null;
3236
c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s);
37+
Assert.AreEqual(attr_value.Type, DataType.DtInt32);
38+
39+
// Test not found errors in TF_Operation*() query functions.
40+
// Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
41+
// Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code);
42+
// Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
43+
// Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message);
44+
45+
// Make a constant oper with the scalar "3".
46+
var three = c_test_util.ScalarConst(3, graph, s);
47+
48+
// Add oper.
49+
var add = c_test_util.Add(feed, three, graph, s);
3350
}
3451
}
3552
}

0 commit comments

Comments
 (0)