forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathops.cs
More file actions
85 lines (70 loc) · 2.33 KB
/
ops.cs
File metadata and controls
85 lines (70 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Tensorflow;
using node_def_pb2 = Tensorflow;
using Google.Protobuf;
namespace Tensorflow
{
/// <summary>
/// Graph-construction helpers mirroring functions of the same names in TensorFlow's Python <c>ops</c> module.
/// </summary>
public static class ops
{
    // Backing counter for uid(). Updated with Interlocked so concurrent callers never receive the same value.
    private static int _uid_counter;

    /// <summary>
    /// Returns the graph produced by <c>tf.Graph()</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): whether <c>tf.Graph()</c> hands back a shared default graph or a fresh
    /// instance on every call is not visible here — confirm against its definition.
    /// </remarks>
    public static Graph get_default_graph()
    {
        return tf.Graph();
    }

    /// <summary>
    /// Creates a native TensorFlow operation in <paramref name="graph"/> described by <paramref name="node_def"/>.
    /// </summary>
    /// <param name="graph">Graph whose native handle receives the new operation.</param>
    /// <param name="node_def">Supplies the op type, the op name and the attributes to set on the operation.</param>
    /// <param name="inputs">Tensors added as data inputs, in order; <c>null</c> means no inputs.</param>
    /// <returns>The handle returned by <c>TF_FinishOperation</c>.</returns>
    /// <exception cref="Exception">Setting an attribute or finishing the operation reported a non-OK status.</exception>
    public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
    {
        var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);

        // Add inputs
        if (inputs != null)
        {
            foreach (var op_input in inputs)
            {
                c_api.TF_AddInput(op_desc, op_input._as_tf_output());
            }
        }

        var status = new Status();

        // TODO: control inputs are not wired up yet.

        // Add attrs: each AttrValue is serialized and passed to the C API as a raw unmanaged buffer.
        foreach (var attr in node_def.Attr)
        {
            var bytes = attr.Value.ToByteArray();
            var proto = Marshal.AllocHGlobal(bytes.Length);
            try
            {
                Marshal.Copy(bytes, 0, proto, bytes.Length);
                c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
            }
            finally
            {
                // Release the buffer so each attribute does not leak an unmanaged allocation.
                // Assumes TF_SetAttrValueProto only reads (parses) the buffer during the call, per the TF C API.
                Marshal.FreeHGlobal(proto);
            }
            if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
        }

        var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
        if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
        return c_op;
    }

    /// <summary>
    /// Builds a <c>NodeDef</c> protobuf message from its parts.
    /// </summary>
    /// <param name="op_type">Value stored in <c>NodeDef.Op</c>.</param>
    /// <param name="name">Value stored in <c>NodeDef.Name</c>.</param>
    /// <param name="device">Value stored in <c>NodeDef.Device</c>; <c>null</c> or empty leaves the field unset.</param>
    /// <param name="attrs">Attributes copied into <c>NodeDef.Attr</c>; <c>null</c> means no attributes.</param>
    /// <returns>The populated <c>NodeDef</c>.</returns>
    public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
    {
        var node_def = new node_def_pb2.NodeDef();
        node_def.Op = op_type;
        node_def.Name = name;

        // attrs is optional (defaults to null), so it must be guarded before enumeration.
        if (attrs != null)
        {
            foreach (var attr in attrs)
            {
                node_def.Attr.Add(attr.Key, attr.Value);
            }
        }

        // Only assign a real device string; protobuf string setters reject null.
        if (!string.IsNullOrEmpty(device))
        {
            node_def.Device = device;
        }

        return node_def;
    }

    /// <summary>
    /// Converts a scope name to a name by removing a single trailing "/".
    /// </summary>
    /// <param name="name">Scope name; <c>null</c> or empty is returned unchanged.</param>
    /// <returns><paramref name="name"/> without its trailing "/", or <paramref name="name"/> unchanged when it has none.</returns>
    public static string _name_from_scope_name(string name)
    {
        // Ordinal comparison: the culture-sensitive EndsWith(string) overload can report a match when
        // ignorable characters follow the "/", and Substring would then drop the wrong character.
        if (!string.IsNullOrEmpty(name) && name.EndsWith("/", StringComparison.Ordinal))
        {
            return name.Substring(0, name.Length - 1);
        }

        return name;
    }

    /// <summary>
    /// Returns a process-wide unique integer id: 1 on the first call, then 2, 3, and so on.
    /// </summary>
    /// <remarks>Thread-safe; the counter wraps after <see cref="int.MaxValue"/> calls.</remarks>
    public static int uid()
    {
        return Interlocked.Increment(ref _uid_counter);
    }
}
}