forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathops.py.cs
More file actions
439 lines (382 loc) · 17 KB
/
ops.py.cs
File metadata and controls
439 lines (382 loc) · 17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Tensorflow;
using node_def_pb2 = Tensorflow;
using Google.Protobuf;
using System.Linq;
using NumSharp.Core;
using System.ComponentModel;
using Tensorflow.Gradients;
namespace Tensorflow
{
public partial class ops : Python
{
/// <summary>
/// Forwards <paramref name="name"/> and <paramref name="value"/> to <c>Graph.add_to_collection</c>
/// on the graph returned by <c>tf.get_default_graph()</c>.
/// </summary>
/// <typeparam name="T">Type of the value being stored.</typeparam>
/// <param name="name">Key of the collection.</param>
/// <param name="value">Value to add to the collection.</param>
public static void add_to_collection<T>(string name, T value)
    => tf.get_default_graph().add_to_collection(name, value);
/// <summary>
/// Forwards <paramref name="names"/> and <paramref name="value"/> to <c>Graph.add_to_collections</c>
/// on the graph returned by <c>tf.get_default_graph()</c>.
/// </summary>
/// <typeparam name="T">Type of the value being stored.</typeparam>
/// <param name="names">Keys of the collections.</param>
/// <param name="value">Value to add.</param>
public static void add_to_collections<T>(List<string> names, T value)
    => tf.get_default_graph().add_to_collections(names, value);
/// <summary>
/// Wrapper for <c>Graph.get_collection()</c> using the default graph.
/// </summary>
/// <param name="key">
/// The key for the collection. For example, the <c>GraphKeys</c> class
/// contains many standard names for collections.
/// </param>
/// <param name="scope">Optional scope, forwarded unchanged to <c>Graph.get_collection</c>.</param>
/// <returns>
/// Whatever <c>Graph.get_collection</c> returns. Per the upstream TensorFlow contract this is
/// the list of values in the collection with the given <paramref name="key"/> (in the order
/// they were collected), or an empty list if nothing was added to that collection.
/// </returns>
public static object get_collection(string key, string scope = null)
{
    var graph = get_default_graph();
    return graph.get_collection(key, scope);
}
/// <summary>
/// Wrapper for <c>Graph.get_collection_ref()</c> using the default graph.
/// </summary>
/// <param name="key">The key for the collection.</param>
/// <returns>Whatever <c>Graph.get_collection_ref</c> returns for <paramref name="key"/>.</returns>
public static object get_collection_ref(string key)
{
    var graph = get_default_graph();
    return graph.get_collection_ref(key);
}
private static Graph default_graph;
/// <summary>
/// Returns the default graph, creating it with <c>tf.Graph()</c> the first time it is requested.
/// </summary>
/// <remarks>
/// The graph is held in a static field shared by all callers; the lazy creation is not synchronized.
/// </remarks>
/// <returns>The current default graph.</returns>
public static Graph get_default_graph()
{
    if (default_graph != null)
        return default_graph;

    default_graph = tf.Graph();
    return default_graph;
}
/// <summary>
/// Replaces the graph returned by <see cref="get_default_graph"/>.
/// </summary>
/// <param name="graph">
/// The graph to install. Passing null makes the next <see cref="get_default_graph"/> call create a new graph.
/// </param>
/// <returns>The graph that was installed.</returns>
public static Graph set_default_graph(Graph graph)
    => default_graph = graph;
/// <summary>
/// Returns the graph that an op built from <paramref name="op_input_list"/> should be added to.
/// </summary>
/// <param name="op_input_list">Inputs of the op being created. Currently not inspected.</param>
/// <param name="graph">Explicit graph to use; when null, the default graph is returned.</param>
/// <returns><paramref name="graph"/> when supplied, otherwise <see cref="get_default_graph"/>.</returns>
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
{
    // TODO: like the Python implementation, validate that every input in
    // op_input_list belongs to the same graph before returning it.
    return graph ?? get_default_graph();
}
/// <summary>
/// Converts the given <paramref name="value"/> to a <see cref="Tensor"/>.
/// </summary>
/// <param name="value">The object to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c> to leave it unspecified.</param>
/// <param name="name">Optional name for a tensor created by the conversion.</param>
/// <param name="preferred_dtype">Element type used when <paramref name="dtype"/> is <c>DtInvalid</c>.</param>
/// <returns>The converted tensor.</returns>
public static Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
    => convert_to_tensor_v2(value, dtype: dtype, dtype_hint: preferred_dtype, name: name);
/// <summary>
/// Converts <paramref name="value"/> to a <see cref="Tensor"/> without requesting a reference (<c>as_ref = false</c>).
/// </summary>
/// <param name="value">The object to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="dtype_hint">Forwarded as <c>preferred_dtype</c>.</param>
/// <param name="name">Optional name for a created tensor.</param>
/// <returns>The converted tensor.</returns>
public static Tensor convert_to_tensor_v2(object value, TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype_hint = TF_DataType.DtInvalid, string name = null)
    => internal_convert_to_tensor(value, as_ref: false, preferred_dtype: dtype_hint, name: name, dtype: dtype);
/// <summary>
/// Non-reference variant of <see cref="internal_convert_to_tensor_or_composite"/>.
/// </summary>
/// <param name="value">The tensor to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Optional name.</param>
/// <returns>The converted tensor.</returns>
public static Tensor convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
    => internal_convert_to_tensor_or_composite(value, as_ref: false, name: name, dtype: dtype);
/// <summary>
/// Converts <paramref name="value"/> by delegating to <see cref="internal_convert_to_tensor"/>.
/// </summary>
/// <param name="value">The tensor to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Optional name.</param>
/// <param name="as_ref">Forwarded unchanged.</param>
/// <returns>The converted tensor.</returns>
public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
    => internal_convert_to_tensor(value, as_ref: as_ref, name: name, dtype: dtype);
/// <summary>
/// Wrapper for <c>Graph.control_dependencies()</c> using the default graph.
/// </summary>
/// <param name="control_inputs">Operations passed to <c>Graph.control_dependencies</c>.</param>
/// <returns>The controller returned by the default graph.</returns>
public static _ControlDependenciesController control_dependencies(Operation[] control_inputs)
{
    var graph = get_default_graph();
    return graph.control_dependencies(control_inputs);
}
/// <summary>
/// Creates a TF_Operation.
/// </summary>
/// <param name="graph">a `Graph`.</param>
/// <param name="node_def">
/// `node_def_pb2.NodeDef` for the operation to create; its <c>Op</c>, <c>Name</c> and <c>Attr</c> fields are used.
/// </param>
/// <param name="inputs">
/// A list of `Tensor`s (corresponding to scalar inputs) and lists of
/// `Tensor`s (corresponding to sequence inputs, e.g. "int64 * N",
/// "list(int64)"). The length of the list should be equal to the number of
/// inputs specified by this operation's op def.
/// </param>
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies; null is treated as empty.</param>
/// <returns>A wrapped TF_Operation*.</returns>
/// <exception cref="NotImplementedException">
/// An element of <paramref name="inputs"/> is neither a <see cref="Tensor"/> nor a <see cref="Tensor"/> array.
/// </exception>
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
{
    var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

    // Add inputs
    foreach (var op_input in inputs)
    {
        if (op_input is Tensor[] op_inputs)
        {
            c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
        }
        else if (op_input is Tensor op_input1)
        {
            // NOTE(review): when the tensor has no producing op, the input is built from
            // `op_desc` (the description still being assembled) rather than from an
            // existing operation. Confirm this is intended.
            if (op_input1.op == null)
                c_api.TF_AddInput(op_desc, new TF_Output(op_desc, 0));
            else
                c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
        }
        else
        {
            throw new NotImplementedException("_create_c_op");
        }
    }

    var status = new Status();

    // Add control inputs
    if (control_inputs != null)
    {
        foreach (var control_input in control_inputs)
            c_api.TF_AddControlInput(op_desc, control_input);
    }

    // Add attrs: each AttrValue is serialized and copied into unmanaged memory for the C API.
    foreach (var attr in node_def.Attr)
    {
        var bytes = attr.Value.ToByteArray();
        var proto = Marshal.AllocHGlobal(bytes.Length);
        try
        {
            Marshal.Copy(bytes, 0, proto, bytes.Length);
            uint len = (uint)bytes.Length;
            c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);
        }
        finally
        {
            // TF_SetAttrValueProto parses the buffer during the call and does not retain it,
            // so the unmanaged copy is released here instead of leaking once per attribute.
            Marshal.FreeHGlobal(proto);
        }
        status.Check(true);
    }

    var c_op = c_api.TF_FinishOperation(op_desc, status);
    status.Check(true);
    return c_op;
}
/// <summary>
/// Looks up the <see cref="OpDef"/> for op type <paramref name="type"/> through <c>Graph.GetOpDef</c>.
/// </summary>
/// <param name="graph">Graph used for the lookup.</param>
/// <param name="type">Op type name.</param>
/// <returns>The op definition returned by the graph.</returns>
public static OpDef _get_op_def(Graph graph, string type)
    => graph.GetOpDef(type);
/// <summary>
/// Builds a <see cref="NodeDef"/> protocol buffer for an op.
/// </summary>
/// <param name="op_type">Value stored in <c>NodeDef.Op</c>.</param>
/// <param name="name">Value stored in <c>NodeDef.Name</c>.</param>
/// <param name="device">Device string stored in <c>NodeDef.Device</c>; null or empty leaves it unset.</param>
/// <param name="attrs">Attributes copied into <c>NodeDef.Attr</c>; may be null.</param>
/// <returns>The populated node definition.</returns>
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
{
    var node_def = new node_def_pb2.NodeDef();
    node_def.Op = op_type;
    node_def.Name = name;

    // `attrs` defaults to null, so it must be checked before iterating.
    if (attrs != null)
    {
        foreach (var attr in attrs)
            node_def.Attr.Add(attr.Key, attr.Value);
    }

    // Mirror the Python `_NodeDef`, which records the device when one is given.
    if (!string.IsNullOrEmpty(device))
        node_def.Device = device;

    return node_def;
}
/// <summary>
/// Returns <paramref name="name"/> without its trailing '/', if it has one.
/// </summary>
/// <param name="name">A name scope such as "outer/inner/".</param>
/// <returns>
/// The name with one trailing slash removed. Null or empty input is returned unchanged.
/// </returns>
public static string _name_from_scope_name(string name)
{
    // Ordinal comparison: scope names are identifiers, not linguistic text.
    if (string.IsNullOrEmpty(name) || !name.EndsWith("/", StringComparison.Ordinal))
        return name;

    return name.Substring(0, name.Length - 1);
}
/// <summary>
/// Port of TensorFlow's Python <c>init_scope</c> context manager, which lifts ops out of
/// control-flow scopes and function-building graphs.
/// </summary>
/// <remarks>
/// NOTE(review): this port is incomplete. It computes the current name scope into <c>scope</c>
/// and, inside <c>with(ops.control_dependencies(null), ...)</c>, reads the default graph into
/// <c>outer_graph</c>, but neither value is used afterwards and nothing is returned.
/// </remarks>
public static void init_scope()
{
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
// NOTE: this local shadows the static `default_graph` field.
var default_graph = get_default_graph();
var scope = default_graph.get_name_scope();
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
// Names that end with trailing slashes are treated by `name_scope` as
// absolute.
scope += "/";
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;
// Presumably `with` enters the controller from `control_dependencies(null)` (clearing
// control dependencies) for the duration of the delegate. TODO confirm against `with`.
with(ops.control_dependencies(null), delegate
{
var outer_graph = get_default_graph();
// outer_device_stack = None
});
}
private static int uid_number = 0;
/// <summary>
/// A unique (within this program execution) integer.
/// </summary>
/// <remarks>
/// Values start at 0 and increase by one per call. The counter is advanced atomically, so
/// concurrent callers never receive the same value (until the int counter wraps around).
/// </remarks>
/// <returns>The next identifier.</returns>
public static int uid()
{
    // Interlocked.Increment returns the incremented value; subtract one to keep the
    // post-increment sequence 0, 1, 2, ...
    return Interlocked.Increment(ref uid_number) - 1;
}
/// <summary>
/// Colocates subsequently created ops with <paramref name="op"/> on the default graph (no gradient uid).
/// </summary>
/// <param name="op">Operation to colocate with.</param>
/// <param name="ignore_existing">Forwarded to <c>_colocate_with_for_gradient</c>.</param>
public static void colocate_with(Operation op, bool ignore_existing = false)
    => _colocate_with_for_gradient(op, gradient_uid: null, ignore_existing: ignore_existing);
/// <summary>
/// Colocates with the operation that produces <paramref name="tensor"/> (no gradient uid).
/// </summary>
/// <param name="tensor">Tensor whose <c>op</c> is used.</param>
/// <param name="ignore_existing">Forwarded to <c>_colocate_with_for_gradient</c>.</param>
public static void colocate_with(Tensor tensor, bool ignore_existing = false)
    => _colocate_with_for_gradient(tensor.op, gradient_uid: null, ignore_existing: ignore_existing);
/// <summary>
/// Forwards to <c>Graph._colocate_with_for_gradient</c> on the default graph.
/// </summary>
/// <param name="op">Operation to colocate with.</param>
/// <param name="gradient_uid">Gradient identifier, or null.</param>
/// <param name="ignore_existing">Forwarded unchanged.</param>
public static void _colocate_with_for_gradient(Operation op, string gradient_uid, bool ignore_existing = false)
    => get_default_graph()._colocate_with_for_gradient(op, gradient_uid, ignore_existing);
/// <summary>
/// Evaluates <paramref name="tensor"/> with <paramref name="session"/>, or with the default session when none is given.
/// </summary>
/// <param name="tensor">The tensor to evaluate.</param>
/// <param name="feed_dict">Feed items passed through to <c>Session.run</c>.</param>
/// <param name="graph">The graph in which the tensor is defined.</param>
/// <param name="session">A different session to use to evaluate the tensor; when null the default session is used.</param>
/// <returns>The result of <c>session.run(tensor, feed_dict)</c>.</returns>
/// <exception cref="ValueError">
/// No session was given and no default session is registered, or the session's graph is not <paramref name="graph"/>.
/// </exception>
public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed_dict, Graph graph, Session session = null)
{
    if (session == null)
    {
        session = get_default_session();
        if (session == null)
            throw new ValueError("Cannot evaluate tensor using `eval()`: No default " +
                "session is registered. Use `with " +
                "sess.as_default()` or pass an explicit session to " +
                "`eval(session=sess)`");
        if (session.graph != graph)
            throw new ValueError("Cannot use the default session to evaluate tensor: " +
                "the tensor's graph is different from the session's " +
                "graph. Pass an explicit session to " +
                "`eval(session=sess)`.");
    }
    else if (session.graph != graph)
    {
        // An explicit session was supplied, so the message must not blame the default session
        // or suggest passing an explicit one.
        throw new ValueError("Cannot use the given session to evaluate tensor: " +
            "the tensor's graph is different from the session's " +
            "graph.");
    }

    return session.run(tensor, feed_dict);
}
/// <summary>
/// Returns the default session.
/// </summary>
/// <returns>
/// The value of <c>tf.defaultSession</c>. Callers in this class treat null as "no default session registered".
/// </returns>
public static Session get_default_session()
    => tf.defaultSession;
/// <summary>
/// Prepends name scope to a name.
/// </summary>
/// <param name="name">A name, optionally starting with a "^" or "loc:@" marker.</param>
/// <param name="import_scope">
/// Scope to prepend; one trailing '/' is ignored. When null or empty, <paramref name="name"/> is returned unchanged.
/// </param>
/// <returns>
/// <paramref name="name"/> with "import_scope/" inserted after any leading "^" or "loc:@" marker,
/// e.g. ("^foo", "imp") gives "^imp/foo" and ("foo", "imp/") gives "imp/foo".
/// A null <paramref name="name"/> is returned unchanged.
/// </returns>
public static string prepend_name_scope(string name, string import_scope)
{
    if (string.IsNullOrEmpty(import_scope) || name == null)
        return name;

    if (import_scope.EndsWith("/", StringComparison.Ordinal))
        import_scope = import_scope.Substring(0, import_scope.Length - 1);

    // Equivalent to the Python implementation's pattern `([\^]|loc:@|^)(.*)`. A MatchEvaluator is
    // used instead of a replacement string so '$' characters in the scope are inserted literally.
    var pattern = new System.Text.RegularExpressions.Regex(@"(\^|loc:@|^)(.*)");
    var prefix = import_scope + "/";
    return pattern.Replace(name, m => m.Groups[1].Value + prefix + m.Groups[2].Value);
}
/// <summary>
/// Runs <paramref name="operation"/> with <paramref name="session"/>, or with the default session when none is given.
/// </summary>
/// <param name="operation">The operation to execute.</param>
/// <param name="feed_dict">Feed items passed through to <c>Session.run</c>.</param>
/// <param name="graph">The graph in which the operation is defined.</param>
/// <param name="session">Session to use; when null the default session is used.</param>
/// <exception cref="ValueError">
/// No session was given and no default session is registered, or the session's graph is not <paramref name="graph"/>.
/// </exception>
public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session = null)
{
    if (session == null)
    {
        session = get_default_session();
        if (session == null)
            throw new ValueError("Cannot execute operation using `run()`: No default " +
                "session is registered. Use `with " +
                "sess.as_default():` or pass an explicit session to " +
                "`run(session=sess)`");
        if (session.graph != graph)
            throw new ValueError("Cannot use the default session to execute operation: " +
                "the operation's graph is different from the " +
                "session's graph. Pass an explicit session to " +
                "run(session=sess).");
    }
    else if (session.graph != graph)
    {
        // An explicit session was supplied, so the message must not blame the default session
        // or suggest passing an explicit one.
        throw new ValueError("Cannot use the given session to execute operation: " +
            "the operation's graph is different from the session's " +
            "graph.");
    }

    session.run(operation, feed_dict);
}
/// <summary>
/// Converts every element of <paramref name="values"/> to a tensor (non-reference conversion).
/// </summary>
/// <param name="values">Values to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Optional base name for the created tensors.</param>
/// <returns>One tensor per element.</returns>
public static Tensor[] convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
    return internal_convert_n_to_tensor(values, as_ref: false, name: name, dtype: dtype);
}
/// <summary>
/// Converts each tensor in <paramref name="values"/>, passing null entries through.
/// </summary>
/// <param name="values">Tensors to convert; may contain nulls.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Optional base name.</param>
/// <returns>The converted tensors.</returns>
public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
    return internal_convert_n_to_tensor_or_indexed_slices(values, dtype, name);
}
/// <summary>
/// Non-reference variant of <see cref="internal_convert_to_tensor_or_indexed_slices"/>.
/// </summary>
/// <param name="value">The tensor to convert.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Optional name.</param>
/// <returns>The converted tensor.</returns>
public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
    return internal_convert_to_tensor_or_indexed_slices(value, dtype, name, as_ref: false);
}
/// <summary>
/// Identity conversion: the argument is already a <see cref="Tensor"/> and is returned as-is.
/// </summary>
/// <remarks>
/// <paramref name="dtype"/>, <paramref name="name"/> and <paramref name="as_ref"/> are accepted for API parity but are not used.
/// </remarks>
/// <returns><paramref name="value"/> unchanged.</returns>
public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
    return value;
}
/// <summary>
/// Converts each element of <paramref name="values"/> with
/// <see cref="internal_convert_to_tensor_or_indexed_slices"/>; null elements are kept as null.
/// </summary>
/// <param name="values">Tensors to convert; may contain nulls.</param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Base name; element <c>i</c> gets <c>{name}_{i}</c> when non-empty, otherwise "".</param>
/// <param name="as_ref">Forwarded unchanged.</param>
/// <returns>The converted tensors, in input order.</returns>
public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
    var converted = new List<Tensor>();
    foreach (var (index, item) in Python.enumerate(values))
    {
        if (item == null)
        {
            // Missing entries are passed through untouched.
            converted.Add(item);
            continue;
        }

        var item_name = string.IsNullOrEmpty(name) ? "" : $"{name}_{index}";
        converted.Add(internal_convert_to_tensor_or_indexed_slices(item, dtype: dtype, name: item_name, as_ref: as_ref));
    }
    return converted.ToArray();
}
/// <summary>
/// Converts each element of <paramref name="values"/> to a <see cref="Tensor"/> via <see cref="internal_convert_to_tensor"/>.
/// </summary>
/// <param name="values">
/// An <c>object[]</c> (including reference-type arrays such as <c>Tensor[]</c>) or any other non-string sequence.
/// </param>
/// <param name="dtype">Requested element type, or <c>DtInvalid</c>.</param>
/// <param name="name">Base name; element <c>i</c> gets <c>{name}_{i}</c> when non-empty, otherwise "".</param>
/// <param name="preferred_dtype">Fallback element type forwarded unchanged.</param>
/// <param name="as_ref">Forwarded unchanged.</param>
/// <returns>One tensor per element, in input order.</returns>
/// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
/// <exception cref="ArgumentException"><paramref name="values"/> is not a sequence.</exception>
public static Tensor[] internal_convert_n_to_tensor(object values, TF_DataType dtype = TF_DataType.DtInvalid,
    string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
    bool as_ref = false)
{
    if (values == null)
        throw new ArgumentNullException(nameof(values));

    // Accept any non-string sequence, not just object[]: an `as object[]` cast would yield
    // null for e.g. List<object> or int[] and fail later with a NullReferenceException.
    object[] items;
    if (values is object[] array)
        items = array;
    else if (values is System.Collections.IEnumerable sequence && !(values is string))
        items = sequence.Cast<object>().ToArray();
    else
        throw new ArgumentException("values must be a sequence.", nameof(values));

    var ret = new Tensor[items.Length];
    for (int i = 0; i < items.Length; i++)
    {
        string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
        ret[i] = internal_convert_to_tensor(items[i], dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype);
    }
    return ret;
}
/// <summary>
/// Converts <paramref name="value"/> to a <see cref="Tensor"/>, dispatching on its runtime type.
/// </summary>
/// <param name="value">
/// An <c>NDArray</c>, <see cref="Tensor"/>, <c>Tensor[]</c>, <c>RefVariable</c>, <c>object[]</c>,
/// or any other value accepted by <c>constant_op.constant</c>.
/// </param>
/// <param name="dtype">Requested element type; <c>DtInvalid</c> falls back to <paramref name="preferred_dtype"/>.</param>
/// <param name="name">Optional name for a created tensor.</param>
/// <param name="preferred_dtype">Element type used when <paramref name="dtype"/> is <c>DtInvalid</c>.</param>
/// <param name="as_ref">Forwarded to <c>RefVariable._TensorConversionFunction</c>.</param>
/// <param name="scope">Currently unused.</param>
/// <returns>The converted tensor; an existing <see cref="Tensor"/> is returned as-is.</returns>
public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid,
    string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid,
    bool as_ref = false,
    string scope = null)
{
    // An unspecified dtype falls back to the caller's preference.
    var effective_dtype = dtype == TF_DataType.DtInvalid ? preferred_dtype : dtype;

    if (value is NDArray nd)
        return constant_op.constant(nd, dtype: effective_dtype, name: name);

    if (value is Tensor tensor)
        return tensor;

    // Must be tested before object[]: array covariance makes every Tensor[] an object[] too.
    if (value is Tensor[] tensors)
        return array_ops._autopacking_helper(tensors, effective_dtype, name);

    if (value is RefVariable variable)
        return variable._TensorConversionFunction(as_ref: as_ref);

    if (value is object[] objects)
        return array_ops._autopacking_conversion_function(objects, dtype: effective_dtype, name: name);

    return constant_op.constant(value, dtype: effective_dtype, name: name);
}
/// <summary>
/// Removes <paramref name="export_scope"/> from the front of <paramref name="name"/>.
/// </summary>
/// <remarks>
/// Follows TensorFlow's Python <c>ops.strip_name_scope</c>. It strips "export_scope/",
/// "export_scope///", "^export_scope/" and "loc:@export_scope/", keeps any "^" / "loc:@"
/// marker, and removes only the first occurrence. Unlike the Python version, the scope is
/// regex-escaped, so characters such as '.' match literally.
/// </remarks>
/// <param name="name">The name to strip.</param>
/// <param name="export_scope">Scope to remove; one trailing '/' is ignored. Null or empty returns <paramref name="name"/> unchanged.</param>
/// <returns>The stripped name, or <paramref name="name"/> unchanged when the scope is not found.</returns>
public static string strip_name_scope(string name, string export_scope = "")
{
    if (string.IsNullOrEmpty(export_scope) || name == null)
        return name;

    if (export_scope.EndsWith("/", StringComparison.Ordinal))
        export_scope = export_scope.Substring(0, export_scope.Length - 1);

    var pattern = new System.Text.RegularExpressions.Regex(
        @"(\^|loc:@|^)" + System.Text.RegularExpressions.Regex.Escape(export_scope) + @"/+(.*)");
    // count = 1, matching the Python re.sub(..., count=1).
    return pattern.Replace(name, m => m.Groups[1].Value + m.Groups[2].Value, 1);
}
/// <summary>
/// Returns the current name scope of the default graph.
/// </summary>
/// <returns>The value of <c>Graph.get_name_scope()</c> on the default graph.</returns>
public static string get_name_scope()
    => get_default_graph().get_name_scope();
}
}