Skip to content

Commit 7b24f53

Browse files
committed
added TF_Status, TF_SetAttrValueProto
1 parent a4e4c36 commit 7b24f53

7 files changed

Lines changed: 67 additions & 43 deletions

File tree

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,5 +333,4 @@ ASALocalRun/
333333
/tensorflowlib/osx/native/libtensorflow.dylib
334334
/tensorflowlib/linux/native/libtensorflow_framework.so
335335
/tensorflowlib/linux/native/libtensorflow.so
336-
/src/TensorFlowNET.Core/libtensorflow.dll
337336
/src/TensorFlowNET.Core/tensorflow.dll

src/TensorFlowNET.Core/Status.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow;
5+
6+
namespace TensorFlowNET.Core
7+
{
8+
public class Status
9+
{
10+
private IntPtr _handle;
11+
public IntPtr Handle => _handle;
12+
13+
public string ErrorMessage => c_api.TF_Message(_handle);
14+
15+
public TF_Code Code => c_api.TF_GetCode(_handle);
16+
17+
public Status()
18+
{
19+
_handle = c_api.TF_NewStatus();
20+
}
21+
}
22+
}

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
</ItemGroup>
2323

2424
<ItemGroup>
25-
<None Update="libtensorflow.dll">
25+
<None Update="tensorflow.dll">
2626
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
2727
</None>
2828
</ItemGroup>
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public enum TF_Code
8+
{
9+
TF_OK = 0,
10+
TF_CANCELLED = 1,
11+
TF_UNKNOWN = 2,
12+
TF_INVALID_ARGUMENT = 3,
13+
TF_DEADLINE_EXCEEDED = 4,
14+
TF_NOT_FOUND = 5,
15+
TF_ALREADY_EXISTS = 6,
16+
TF_PERMISSION_DENIED = 7,
17+
TF_UNAUTHENTICATED = 16,
18+
TF_RESOURCE_EXHAUSTED = 8,
19+
TF_FAILED_PRECONDITION = 9,
20+
TF_ABORTED = 10,
21+
TF_OUT_OF_RANGE = 11,
22+
TF_UNIMPLEMENTED = 12,
23+
TF_INTERNAL = 13,
24+
TF_UNAVAILABLE = 14,
25+
TF_DATA_LOSS = 15
26+
}
27+
}

src/TensorFlowNET.Core/c_api.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
using TF_Tensor = System.IntPtr;
1212

1313
using TF_DataType = Tensorflow.DataType;
14-
14+
using Tensorflow;
1515
using static TensorFlowNET.Core.Tensorflow;
1616

1717
namespace TensorFlowNET.Core
@@ -23,6 +23,12 @@ public static class c_api
2323
[DllImport(TensorFlowLibName)]
2424
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);
2525

26+
[DllImport(TensorFlowLibName)]
27+
public static extern unsafe TF_Code TF_GetCode(TF_Status s);
28+
29+
[DllImport(TensorFlowLibName)]
30+
public static extern unsafe string TF_Message(TF_Status s);
31+
2632
[DllImport(TensorFlowLibName)]
2733
public static unsafe extern IntPtr TF_NewGraph();
2834

@@ -39,7 +45,7 @@ public static class c_api
3945
public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper);
4046

4147
[DllImport(TensorFlowLibName)]
42-
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, void* proto, size_t proto_len, TF_Status status);
48+
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, size_t proto_len, TF_Status status);
4349

4450
[DllImport(TensorFlowLibName)]
4551
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);

src/TensorFlowNET.Core/ops.cs

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using tf = TensorFlowNET.Core.Tensorflow;
88
using TF_DataType = Tensorflow.DataType;
99
using node_def_pb2 = Tensorflow;
10+
using Google.Protobuf;
1011

1112
namespace TensorFlowNET.Core
1213
{
@@ -20,49 +21,17 @@ public static Graph get_default_graph()
2021
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
2122
{
2223
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name);
23-
var status = c_api.TF_NewStatus();
24-
25-
// Doesn't work
26-
/*foreach(var attr in node_def.Attr)
27-
{
28-
if (attr.Value.Tensor != null)
29-
{
30-
switch (attr.Value.Tensor.Dtype)
31-
{
32-
case DataType.DtDouble:
33-
var proto = (double*)Marshal.AllocHGlobal(sizeof(double));
34-
*proto = attr.Value.Tensor.DoubleVal[0];
35-
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status);
36-
break;
37-
}
38-
}
39-
else
40-
{
41-
//c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status);
42-
}
43-
} */
24+
var status = new Status();
4425

4526
foreach (var attr in node_def.Attr)
4627
{
47-
if (attr.Value.Tensor == null) continue;
48-
switch (attr.Value.Tensor.Dtype)
49-
{
50-
case DataType.DtDouble:
51-
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
52-
*v = attr.Value.Tensor.DoubleVal[0];
53-
var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
54-
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);
55-
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
56-
break;
57-
case DataType.DtString:
58-
59-
var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8());
60-
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status);
61-
break;
62-
}
28+
var bytes = attr.Value.ToByteArray();
29+
var proto = Marshal.AllocHGlobal(bytes.Length);
30+
Marshal.Copy(bytes, 0, proto, bytes.Length);
31+
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
6332
}
6433

65-
var c_op = c_api.TF_FinishOperation(op_desc, status);
34+
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
6635

6736
return c_op;
6837
}

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ public void Run()
1919
The value returned by the constructor represents the output
2020
of the Constant op.*/
2121
var graph = tf.get_default_graph();
22-
var hello = tf.constant("Hello, TensorFlow!");
22+
var hello = tf.constant(4.0);
23+
//var hello = tf.constant("Hello, TensorFlow!");
2324

2425
// Start tf session
2526
// var sess = tf.Session();

0 commit comments

Comments
 (0)