forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCApiTest.cs
More file actions
218 lines (157 loc) · 8.59 KB
/
CApiTest.cs
File metadata and controls
218 lines (157 loc) · 8.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Buffer = System.Buffer;
namespace TensorFlowNET.UnitTest
{
    /// <summary>
    /// Base class for the C API unit tests. Provides assertion helpers named after
    /// gtest macros (EXPECT_* / ASSERT_* / CHECK_*), thin wrappers that forward to the
    /// identically named <c>c_api</c> functions, and <c>memcpy</c> helpers built on
    /// <c>Buffer.MemoryCopy</c>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the gtest-style naming presumably lets test bodies ported from
    /// TensorFlow's C++ C API tests stay close to the originals — confirm.
    /// </remarks>
    public class CApiTest
    {
        // Short aliases for enum values used throughout the tests.
        protected TF_Code TF_OK = TF_Code.TF_OK;
        protected TF_DataType TF_FLOAT = TF_DataType.TF_FLOAT;
        protected TF_DataType TF_BOOL = TF_DataType.TF_BOOL;

        // ---------------------------------------------------------------------
        // Assertion helpers (all delegate to MSTest's Assert).
        // ---------------------------------------------------------------------

        /// <summary>Asserts that <paramref name="expected"/> is true.</summary>
        protected void EXPECT_TRUE(bool expected, string msg = "")
            => Assert.IsTrue(expected, msg);

        /// <summary>
        /// Asserts equality via <c>Assert.AreEqual</c>. Operands are passed as boxed
        /// objects, so values of different numeric types (e.g. int vs long) compare unequal.
        /// </summary>
        protected void EXPECT_EQ(object expected, object actual, string msg = "")
            => Assert.AreEqual(expected, actual, msg);

        /// <summary>Same as <see cref="EXPECT_EQ"/>.</summary>
        protected void CHECK_EQ(object expected, object actual, string msg = "")
            => Assert.AreEqual(expected, actual, msg);

        /// <summary>Asserts inequality via <c>Assert.AreNotEqual</c> (boxed comparison).</summary>
        protected void EXPECT_NE(object expected, object actual, string msg = "")
            => Assert.AreNotEqual(expected, actual, msg);

        /// <summary>Same as <see cref="EXPECT_NE"/>.</summary>
        protected void CHECK_NE(object expected, object actual, string msg = "")
            => Assert.AreNotEqual(expected, actual, msg);

        /// <summary>
        /// Asserts that the FIRST argument is greater than or equal to the SECOND
        /// (<c>expected &gt;= actual</c>), matching the argument order of gtest's
        /// <c>EXPECT_GE(val1, val2)</c>. The parameter names do not imply which
        /// argument is the bound.
        /// </summary>
        protected void EXPECT_GE(int expected, int actual, string msg = "")
            => Assert.IsTrue(expected >= actual, msg);

        /// <summary>Same as <see cref="EXPECT_EQ"/>.</summary>
        protected void ASSERT_EQ(object expected, object actual, string msg = "")
            => Assert.AreEqual(expected, actual, msg);

        /// <summary>Asserts that <paramref name="condition"/> is true.</summary>
        protected void ASSERT_TRUE(bool condition, string msg = "")
            => Assert.IsTrue(condition, msg);

        // ---------------------------------------------------------------------
        // Thin wrappers: each forwards its arguments unchanged to the c_api
        // member of the same name. The string-returning ones convert the native
        // result with c_api.StringPiece.
        // ---------------------------------------------------------------------

        protected OperationDescription TF_NewOperation(Graph graph, string opType, string opName)
            => c_api.TF_NewOperation(graph, opType, opName);

        protected void TF_AddInput(OperationDescription desc, TF_Output input)
            => c_api.TF_AddInput(desc, input);

        protected Operation TF_FinishOperation(OperationDescription desc, Status s)
            => c_api.TF_FinishOperation(desc, s);

        protected void TF_SetAttrTensor(OperationDescription desc, string attrName, Tensor value, Status s)
            => c_api.TF_SetAttrTensor(desc, attrName, value, s);

        protected void TF_SetAttrType(OperationDescription desc, string attrName, TF_DataType dtype)
            => c_api.TF_SetAttrType(desc, attrName, dtype);

        protected void TF_SetAttrBool(OperationDescription desc, string attrName, bool value)
            => c_api.TF_SetAttrBool(desc, attrName, value);

        protected TF_DataType TFE_TensorHandleDataType(IntPtr h)
            => c_api.TFE_TensorHandleDataType(h);

        protected int TFE_TensorHandleNumDims(IntPtr h, IntPtr status)
            => c_api.TFE_TensorHandleNumDims(h, status);

        /// <summary>Reads the code from a managed <see cref="Status"/> wrapper.</summary>
        protected TF_Code TF_GetCode(Status s)
            => s.Code;

        protected TF_Code TF_GetCode(IntPtr s)
            => c_api.TF_GetCode(s);

        protected string TF_Message(IntPtr s)
            => c_api.StringPiece(c_api.TF_Message(s));

        protected IntPtr TF_NewStatus()
            => c_api.TF_NewStatus();

        protected void TF_DeleteStatus(IntPtr s)
            => c_api.TF_DeleteStatus(s);

        protected void TF_DeleteTensor(IntPtr t)
            => c_api.TF_DeleteTensor(t);

        protected IntPtr TF_TensorData(IntPtr t)
            => c_api.TF_TensorData(t);

        protected ulong TF_TensorByteSize(IntPtr t)
            => c_api.TF_TensorByteSize(t);

        protected void TFE_OpAddInput(IntPtr op, IntPtr h, IntPtr status)
            => c_api.TFE_OpAddInput(op, h, status);

        protected void TFE_OpSetAttrType(IntPtr op, string attr_name, TF_DataType value)
            => c_api.TFE_OpSetAttrType(op, attr_name, value);

        protected void TFE_OpSetAttrShape(IntPtr op, string attr_name, long[] dims, int num_dims, IntPtr out_status)
            => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status);

        protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length)
            => c_api.TFE_OpSetAttrString(op, attr_name, value, length);

        protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, IntPtr status)
            => c_api.TFE_NewOp(ctx, op_or_function_name, status);

        protected IntPtr TFE_NewTensorHandle(IntPtr t, IntPtr status)
            => c_api.TFE_NewTensorHandle(t, status);

        protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, IntPtr status)
            => c_api.TFE_Execute(op, retvals, ref num_retvals, status);

        protected IntPtr TFE_NewContextOptions()
            => c_api.TFE_NewContextOptions();

        protected void TFE_DeleteContext(IntPtr t)
            => c_api.TFE_DeleteContext(t);

        protected IntPtr TFE_NewContext(IntPtr opts, IntPtr status)
            => c_api.TFE_NewContext(opts, status);

        protected void TFE_DeleteContextOptions(IntPtr opts)
            => c_api.TFE_DeleteContextOptions(opts);

        protected int TFE_OpGetInputLength(IntPtr op, string input_name, IntPtr status)
            => c_api.TFE_OpGetInputLength(op, input_name, status);

        protected int TFE_OpAddInputList(IntPtr op, IntPtr[] inputs, int num_inputs, IntPtr status)
            => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status);

        protected int TFE_OpGetOutputLength(IntPtr op, string input_name, IntPtr status)
            => c_api.TFE_OpGetOutputLength(op, input_name, status);

        protected void TFE_DeleteTensorHandle(IntPtr h)
            => c_api.TFE_DeleteTensorHandle(h);

        protected void TFE_DeleteOp(IntPtr op)
            => c_api.TFE_DeleteOp(op);

        protected void TFE_DeleteExecutor(IntPtr executor)
            => c_api.TFE_DeleteExecutor(executor);

        protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx)
            => c_api.TFE_ContextGetExecutorForThread(ctx);

        protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, IntPtr status)
            => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);

        protected IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status)
            => c_api.TFE_TensorHandleResolve(h, status);

        protected string TFE_TensorHandleDeviceName(IntPtr h, IntPtr status)
            => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(h, status));

        protected string TFE_TensorHandleBackingDeviceName(IntPtr h, IntPtr status)
            => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status));

        protected IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status)
            => c_api.TFE_ContextListDevices(ctx, status);

        protected int TF_DeviceListCount(IntPtr list)
            => c_api.TF_DeviceListCount(list);

        protected string TF_DeviceListType(IntPtr list, int index, IntPtr status)
            => c_api.StringPiece(c_api.TF_DeviceListType(list, index, status));

        protected string TF_DeviceListName(IntPtr list, int index, IntPtr status)
            => c_api.StringPiece(c_api.TF_DeviceListName(list, index, status));

        protected void TF_DeleteDeviceList(IntPtr list)
            => c_api.TF_DeleteDeviceList(list);

        protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, IntPtr status)
            => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);

        protected void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status)
            => c_api.TFE_OpSetDevice(op, device_name, status);

        // ---------------------------------------------------------------------
        // memcpy helpers. All sizes are in BYTES, not elements.
        // ---------------------------------------------------------------------

        /// <summary>Copies <paramref name="size"/> bytes from <paramref name="src"/> to <paramref name="dst"/>.</summary>
        /// <remarks>
        /// Raw pointers carry no length, so no bounds check is possible: the caller must
        /// guarantee both buffers span at least <paramref name="size"/> bytes.
        /// </remarks>
        protected unsafe void memcpy<T>(T* dst, void* src, ulong size)
            where T : unmanaged
        {
            Buffer.MemoryCopy(src, dst, size, size);
        }

        /// <summary>Copies <paramref name="size"/> bytes from <paramref name="src"/> to <paramref name="dst"/>.</summary>
        /// <remarks>Unchecked: the caller must guarantee both buffers span at least <paramref name="size"/> bytes.</remarks>
        protected unsafe void memcpy<T>(void* dst, T* src, ulong size)
            where T : unmanaged
        {
            Buffer.MemoryCopy(src, dst, size, size);
        }

        /// <summary>Copies <paramref name="size"/> bytes from native memory at <paramref name="src"/> to <paramref name="dst"/>.</summary>
        /// <remarks>Unchecked: the caller must guarantee both buffers span at least <paramref name="size"/> bytes.</remarks>
        protected unsafe void memcpy(void* dst, IntPtr src, ulong size)
        {
            Buffer.MemoryCopy(src.ToPointer(), dst, size, size);
        }

        /// <summary>Copies <paramref name="size"/> bytes from native memory at <paramref name="src"/> into the array <paramref name="dst"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="dst"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> exceeds the byte length of <paramref name="dst"/>.</exception>
        protected unsafe void memcpy<T>(T[] dst, IntPtr src, ulong size)
            where T : unmanaged
        {
            if (dst == null)
            {
                throw new ArgumentNullException(nameof(dst));
            }

            // Passing the array's real capacity (instead of `size`) as the destination
            // size lets MemoryCopy reject a copy that would overrun the managed array,
            // rather than silently corrupting the heap. Pinning the array itself (not
            // &dst[0]) yields a null pointer for an empty array, which is fine for a
            // zero-byte copy instead of throwing IndexOutOfRangeException.
            fixed (T* p = dst)
            {
                Buffer.MemoryCopy(src.ToPointer(), p, (ulong)ByteLength(dst), size);
            }
        }

        /// <summary>Signed-size overload of <c>memcpy(T[] dst, IntPtr src, ulong size)</c>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="dst"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is negative or exceeds the byte length of <paramref name="dst"/>.</exception>
        protected unsafe void memcpy<T>(T[] dst, IntPtr src, long size)
            where T : unmanaged
        {
            if (size < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(size), size, "Copy size must be non-negative.");
            }

            memcpy(dst, src, (ulong)size);
        }

        /// <summary>Copies <paramref name="size"/> bytes from the array <paramref name="src"/> to native memory at <paramref name="dst"/>.</summary>
        /// <remarks>The destination has no known length; the caller must guarantee it spans at least <paramref name="size"/> bytes.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="src"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> exceeds the byte length of <paramref name="src"/>.</exception>
        protected unsafe void memcpy<T>(IntPtr dst, T[] src, ulong size)
            where T : unmanaged
        {
            if (src == null)
            {
                throw new ArgumentNullException(nameof(src));
            }

            // MemoryCopy only validates the destination size, so reading past the end
            // of the managed source array must be prevented explicitly.
            if (size > (ulong)ByteLength(src))
            {
                throw new ArgumentOutOfRangeException(nameof(size), size, "Copy size exceeds the byte length of the source array.");
            }

            fixed (T* p = src)
            {
                Buffer.MemoryCopy(p, dst.ToPointer(), size, size);
            }
        }

        /// <summary>Signed-size overload of <c>memcpy(IntPtr dst, T[] src, ulong size)</c>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="src"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="size"/> is negative or exceeds the byte length of <paramref name="src"/>.</exception>
        protected unsafe void memcpy<T>(IntPtr dst, T[] src, long size)
            where T : unmanaged
        {
            if (size < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(size), size, "Copy size must be non-negative.");
            }

            memcpy(dst, src, (ulong)size);
        }

        /// <summary>Total size in bytes of the elements stored in <paramref name="array"/>.</summary>
        private static unsafe long ByteLength<T>(T[] array)
            where T : unmanaged
            => (long)array.Length * sizeof(T);
    }
}