Skip to content

Commit a4e4c36

Browse files
committed
protobuf SciSharp#3
1 parent 5a58965 commit a4e4c36

11 files changed

Lines changed: 202 additions & 28 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,4 @@ ASALocalRun/
334334
/tensorflowlib/linux/native/libtensorflow_framework.so
335335
/tensorflowlib/linux/native/libtensorflow.so
336336
/src/TensorFlowNET.Core/libtensorflow.dll
337+
/src/TensorFlowNET.Core/tensorflow.dll

src/TensorFlowNET.Core/Graph.cs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Linq;
44
using System.Runtime.InteropServices;
55
using System.Text;
6-
6+
using Tensorflow;
77
using TF_DataType = Tensorflow.DataType;
88

99
namespace TensorFlowNET.Core
@@ -19,6 +19,7 @@ public class Graph
1919
public IntPtr handle;
2020
private Dictionary<int, Operation> _nodes_by_id;
2121
private Dictionary<string, Operation> _nodes_by_name;
22+
private Dictionary<string, int> _names_in_use;
2223
public int _version;
2324
private int _next_id_counter;
2425

@@ -27,17 +28,20 @@ public Graph(IntPtr graph)
2728
this.handle = graph;
2829
_nodes_by_id = new Dictionary<int, Operation>();
2930
_nodes_by_name = new Dictionary<string, Operation>();
31+
_names_in_use = new Dictionary<string, int>();
3032
}
3133

32-
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = "")
34+
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, Dictionary<string, AttrValue> attrs = null, string name = "Const")
3335
{
3436
if (String.IsNullOrEmpty(name))
3537
{
36-
op_type = name;
38+
name = op_type;
3739
}
3840

39-
var op = new Operation(this, inputs);
40-
op.name = name;
41+
name = unique_name(name);
42+
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
43+
44+
var op = new Operation(node_def, this, inputs, dtypes);
4145

4246
return op;
4347
}
@@ -54,6 +58,22 @@ public int _next_id()
5458
return ++_next_id_counter;
5559
}
5660

61+
public string unique_name(string name)
62+
{
63+
var name_key = name.ToLower();
64+
if (_names_in_use.ContainsKey(name_key))
65+
{
66+
_names_in_use[name_key]++;
67+
}
68+
else
69+
{
70+
_names_in_use[name_key] = 1;
71+
}
72+
73+
74+
return $"{name}_{_names_in_use[name_key]}";
75+
}
76+
5777
public Operation[] get_operations()
5878
{
5979
return _nodes_by_name.Values.Select(x => x).ToArray();

src/TensorFlowNET.Core/Operation.cs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow;
5+
using TF_DataType = Tensorflow.DataType;
46

57
namespace TensorFlowNET.Core
68
{
@@ -11,13 +13,23 @@ public class Operation
1113
public int _id => _id_value;
1214
private int _id_value;
1315
public string name;
16+
private Tensor[] _outputs;
17+
public Tensor[] outputs => _outputs;
1418

15-
public Operation(Graph g, object inputs)
19+
public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", string op_def = "")
1620
{
1721
_graph = g;
1822

1923
_id_value = _graph._next_id();
20-
_c_op = ops._create_c_op(g, inputs);
24+
_c_op = ops._create_c_op(g, node_def, inputs);
25+
var num_outputs = c_api.TF_OperationNumOutputs(_c_op);
26+
27+
_outputs = new Tensor[num_outputs];
28+
for (int i = 0; i < num_outputs; i++)
29+
{
30+
_outputs[i] = new Tensor(this, i, TF_DataType.DtDouble);
31+
}
32+
2133
_graph._add_op(this);
2234
}
2335
}

src/TensorFlowNET.Core/Tensor.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow;
45

56
namespace TensorFlowNET.Core
67
{
78
public class Tensor
89
{
10+
private Operation _op;
11+
private int _value_index;
12+
private DataType _dtype;
13+
14+
public Tensor(Operation op, int value_index, DataType dtype)
15+
{
16+
_op = op;
17+
_value_index = value_index;
18+
_dtype = dtype;
19+
}
920
}
1021
}

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
<DefineConstants>DEBUG;TRACE</DefineConstants>
1010
</PropertyGroup>
1111

12+
<ItemGroup>
13+
<None Remove="Tensorflow\README.md" />
14+
</ItemGroup>
15+
1216
<ItemGroup>
1317
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
1418
</ItemGroup>
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Google.Protobuf.Collections;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
using tensor_shape_pb2 = Tensorflow;
7+
8+
namespace TensorFlowNET.Core
9+
{
10+
public class TensorShape
11+
{
12+
private int[] _dims;
13+
14+
public TensorShape()
15+
{
16+
17+
}
18+
19+
public TensorShape as_shape()
20+
{
21+
return this;
22+
}
23+
24+
public TensorShapeProto as_proto()
25+
{
26+
TensorShapeProto dim = new TensorShapeProto();
27+
28+
return new TensorShapeProto(dim);
29+
}
30+
}
31+
}

src/TensorFlowNET.Core/Tensorflow.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
using System.Runtime.InteropServices;
44
using System.Text;
55
using TF_DataType = Tensorflow.DataType;
6+
using attr_value_pb2 = Tensorflow;
7+
using Tensorflow;
68

79
namespace TensorFlowNET.Core
810
{
@@ -13,9 +15,20 @@ public static class Tensorflow
1315
public static unsafe Tensor constant(object value)
1416
{
1517
var g = ops.get_default_graph();
16-
g.create_op("Const", value, new TF_DataType[] { TF_DataType.DtDouble });
17-
18-
return new Tensor();
18+
var tensor_value = new attr_value_pb2.AttrValue();
19+
var tensor_pb = tensor_util.make_tensor_proto(value);
20+
tensor_value.Tensor = tensor_pb;
21+
var dtype_value = new attr_value_pb2.AttrValue
22+
{
23+
Type = tensor_value.Tensor.Dtype,
24+
};
25+
26+
var attrs = new Dictionary<string, AttrValue>();
27+
attrs["dtype"] = dtype_value;
28+
attrs["value"] = tensor_value;
29+
var const_tensor = g.create_op("Const", null, new TF_DataType[] { dtype_value.Type }, attrs: attrs).outputs[0];
30+
31+
return const_tensor;
1932
}
2033

2134
public static Deallocator FreeTensorDataDelegate = FreeTensorData;

src/TensorFlowNET.Core/c_api.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace TensorFlowNET.Core
1818
{
1919
public static class c_api
2020
{
21-
public const string TensorFlowLibName = "libtensorflow";
21+
public const string TensorFlowLibName = "tensorflow";
2222

2323
[DllImport(TensorFlowLibName)]
2424
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);
@@ -35,6 +35,12 @@ public static class c_api
3535
[DllImport(TensorFlowLibName)]
3636
public static extern unsafe TF_Tensor TF_NewTensor(TF_DataType dataType, Int64 dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);
3737

38+
[DllImport(TensorFlowLibName)]
39+
public static extern unsafe int TF_OperationNumOutputs(TF_Operation oper);
40+
41+
[DllImport(TensorFlowLibName)]
42+
public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, void* proto, size_t proto_len, TF_Status status);
43+
3844
[DllImport(TensorFlowLibName)]
3945
public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, TF_Tensor value, TF_Status status);
4046

src/TensorFlowNET.Core/ops.cs

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
using System.Runtime.InteropServices;
44
using System.Text;
55
using System.Threading;
6+
using Tensorflow;
67
using tf = TensorFlowNET.Core.Tensorflow;
78
using TF_DataType = Tensorflow.DataType;
9+
using node_def_pb2 = Tensorflow;
810

911
namespace TensorFlowNET.Core
1012
{
@@ -15,28 +17,73 @@ public static Graph get_default_graph()
1517
return tf.Graph();
1618
}
1719

18-
public static unsafe IntPtr _create_c_op(Graph graph, object inputs)
20+
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
1921
{
20-
var op_desc = c_api.TF_NewOperation(graph.handle, "Const", "Const0");
22+
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name);
2123
var status = c_api.TF_NewStatus();
2224

23-
IntPtr tensor = IntPtr.Zero;
25+
// Doesn't work
26+
/*foreach(var attr in node_def.Attr)
27+
{
28+
if (attr.Value.Tensor != null)
29+
{
30+
switch (attr.Value.Tensor.Dtype)
31+
{
32+
case DataType.DtDouble:
33+
var proto = (double*)Marshal.AllocHGlobal(sizeof(double));
34+
*proto = attr.Value.Tensor.DoubleVal[0];
35+
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status);
36+
break;
37+
}
38+
}
39+
else
40+
{
41+
//c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status);
42+
}
43+
} */
2444

25-
switch (inputs)
45+
foreach (var attr in node_def.Attr)
2646
{
27-
case double value:
28-
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
29-
*v = value;
30-
tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
31-
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
32-
break;
47+
if (attr.Value.Tensor == null) continue;
48+
switch (attr.Value.Tensor.Dtype)
49+
{
50+
case DataType.DtDouble:
51+
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
52+
*v = attr.Value.Tensor.DoubleVal[0];
53+
var tensor = c_api.TF_NewTensor(TF_DataType.DtDouble, 0, 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: Tensorflow.FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero);
54+
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);
55+
c_api.TF_SetAttrType(op_desc, "dtype", TF_DataType.DtDouble);
56+
break;
57+
case DataType.DtString:
58+
59+
var proto = Marshal.StringToHGlobalAnsi(attr.Value.Tensor.StringVal[0].ToStringUtf8());
60+
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto.ToPointer(), proto_len: (UIntPtr)32, status: status);
61+
break;
62+
}
3363
}
3464

35-
c_api.TF_SetAttrTensor(op_desc, "value", tensor, status);
36-
3765
var c_op = c_api.TF_FinishOperation(op_desc, status);
3866

3967
return c_op;
4068
}
69+
70+
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
71+
{
72+
var node_def = new node_def_pb2.NodeDef();
73+
node_def.Op = op_type;
74+
node_def.Name = name;
75+
76+
foreach (var attr in attrs)
77+
{
78+
node_def.Attr.Add(attr.Key, attr.Value);
79+
}
80+
81+
return node_def;
82+
}
83+
84+
public static int uid()
85+
{
86+
return 1;
87+
}
4188
}
4289
}
Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,45 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
5+
using Tensorflow;
46
using np = NumSharp.Core.NumPy;
7+
using tensor_pb2 = Tensorflow;
58

69
namespace TensorFlowNET.Core
710
{
811
public static class tensor_util
912
{
10-
public static void make_tensor_proto(object values, Type dtype = null)
13+
public static TensorProto make_tensor_proto(object values, Type dtype = null)
1114
{
12-
var nparray = np.array(values as Array, dtype);
15+
NDArray nparray;
16+
TensorProto tensor_proto = null;
17+
TensorShape tensor_shape = new TensorShape();
18+
19+
switch (values)
20+
{
21+
case double val:
22+
nparray = np.array(new double[] { val }, np.float64);
23+
tensor_proto = new tensor_pb2.TensorProto
24+
{
25+
Dtype = DataType.DtDouble,
26+
TensorShape = tensor_shape.as_shape().as_proto()
27+
};
28+
tensor_proto.DoubleVal.Add(val);
29+
break;
30+
31+
case string val:
32+
nparray = np.array(new string[] { val }, np.chars);
33+
tensor_proto = new tensor_pb2.TensorProto
34+
{
35+
Dtype = DataType.DtString,
36+
TensorShape = tensor_shape.as_shape().as_proto()
37+
};
38+
tensor_proto.StringVal.Add(Google.Protobuf.ByteString.CopyFrom(val, Encoding.UTF8));
39+
break;
40+
}
41+
42+
return tensor_proto;
1343
}
1444
}
1545
}

0 commit comments

Comments
 (0)