33using System . Runtime . InteropServices ;
44using System . Text ;
55using System . Threading ;
6+ using Tensorflow ;
67using tf = TensorFlowNET . Core . Tensorflow ;
78using TF_DataType = Tensorflow . DataType ;
9+ using node_def_pb2 = Tensorflow ;
810
911namespace TensorFlowNET . Core
1012{
@@ -15,28 +17,73 @@ public static Graph get_default_graph()
1517 return tf . Graph ( ) ;
1618 }
1719
18- public static unsafe IntPtr _create_c_op ( Graph graph , object inputs )
20+ public static unsafe IntPtr _create_c_op ( Graph graph , NodeDef node_def , object inputs )
1921 {
20- var op_desc = c_api . TF_NewOperation ( graph . handle , "Const" , "Const0" ) ;
22+ var op_desc = c_api . TF_NewOperation ( graph . handle , node_def . Op , node_def . Name ) ;
2123 var status = c_api . TF_NewStatus ( ) ;
2224
23- IntPtr tensor = IntPtr . Zero ;
25+ // Doesn't work
26+ /*foreach(var attr in node_def.Attr)
27+ {
28+ if (attr.Value.Tensor != null)
29+ {
30+ switch (attr.Value.Tensor.Dtype)
31+ {
32+ case DataType.DtDouble:
33+ var proto = (double*)Marshal.AllocHGlobal(sizeof(double));
34+ *proto = attr.Value.Tensor.DoubleVal[0];
35+ c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)sizeof(double), status: status);
36+ break;
37+ }
38+ }
39+ else
40+ {
41+ //c_api.TF_SetAttrValueProto(op_desc, attr.Key, null, proto_len: UIntPtr.Zero, status: status);
42+ }
43+ } */
2444
25- switch ( inputs )
45+ foreach ( var attr in node_def . Attr )
2646 {
27- case double value :
28- var v = ( double * ) Marshal . AllocHGlobal ( sizeof ( double ) ) ;
29- * v = value ;
30- tensor = c_api . TF_NewTensor ( TF_DataType . DtDouble , 0 , 0 , data : ( IntPtr ) v , len : ( UIntPtr ) sizeof ( double ) , deallocator : Tensorflow . FreeTensorDataDelegate , deallocator_arg : IntPtr . Zero ) ;
31- c_api . TF_SetAttrType ( op_desc , "dtype" , TF_DataType . DtDouble ) ;
32- break ;
47+ if ( attr . Value . Tensor == null ) continue ;
48+ switch ( attr . Value . Tensor . Dtype )
49+ {
50+ case DataType . DtDouble :
51+ var v = ( double * ) Marshal . AllocHGlobal ( sizeof ( double ) ) ;
52+ * v = attr . Value . Tensor . DoubleVal [ 0 ] ;
53+ var tensor = c_api . TF_NewTensor ( TF_DataType . DtDouble , 0 , 0 , data : ( IntPtr ) v , len : ( UIntPtr ) sizeof ( double ) , deallocator : Tensorflow . FreeTensorDataDelegate , deallocator_arg : IntPtr . Zero ) ;
54+ c_api . TF_SetAttrTensor ( op_desc , "value" , tensor , status ) ;
55+ c_api . TF_SetAttrType ( op_desc , "dtype" , TF_DataType . DtDouble ) ;
56+ break ;
57+ case DataType . DtString :
58+
59+ var proto = Marshal . StringToHGlobalAnsi ( attr . Value . Tensor . StringVal [ 0 ] . ToStringUtf8 ( ) ) ;
60+ c_api . TF_SetAttrValueProto ( op_desc , attr . Key , proto . ToPointer ( ) , proto_len : ( UIntPtr ) 32 , status : status ) ;
61+ break ;
62+ }
3363 }
3464
35- c_api . TF_SetAttrTensor ( op_desc , "value" , tensor , status ) ;
36-
3765 var c_op = c_api . TF_FinishOperation ( op_desc , status ) ;
3866
3967 return c_op ;
4068 }
69+
70+ public static NodeDef _NodeDef ( string op_type , string name , string device = "" , Dictionary < string , AttrValue > attrs = null )
71+ {
72+ var node_def = new node_def_pb2 . NodeDef ( ) ;
73+ node_def . Op = op_type ;
74+ node_def . Name = name ;
75+
76+ foreach ( var attr in attrs )
77+ {
78+ node_def . Attr . Add ( attr . Key , attr . Value ) ;
79+ }
80+
81+ return node_def ;
82+ }
83+
84+ public static int uid ( )
85+ {
86+ return 1 ;
87+ }
4188 }
4289}
0 commit comments