forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathExecute.cs
More file actions
78 lines (70 loc) · 2.73 KB
/
Execute.cs
File metadata and controls
78 lines (70 loc) · 2.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
using System.Collections.Generic;
using System;
using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow.Eager
{
public class Execute
{
    /// <summary>
    /// Execute a TensorFlow operation eagerly through <c>tf.Runner.TFE_Execute</c>
    /// on the context's current device.
    /// </summary>
    /// <param name="ctx">The eager context; it is initialized (if needed) before dispatch.</param>
    /// <param name="op_name">
    /// Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
    /// execute.
    /// </param>
    /// <param name="num_outputs">
    /// The number of outputs of the operation to fetch.
    /// </param>
    /// <param name="inputs">
    /// The input tensors of the operation, passed through to the runner unchanged.
    /// </param>
    /// <param name="attrs">
    /// A flat array with alternating string attr names and attr values for this
    /// operation.
    /// </param>
    /// <param name="name">
    /// Customized name for the operation. NOTE(review): currently accepted for API
    /// parity only; it is not forwarded to <c>TFE_Execute</c>.
    /// </param>
    /// <returns>The output tensors produced by the runner for this operation.</returns>
    public Tensor[] execute(Context ctx, string op_name, int num_outputs,
        Tensor[] inputs, object[] attrs,
        string name = null)
    {
        ctx.ensure_initialized();
        return tf.Runner.TFE_Execute(ctx,
            ctx.device_name,
            op_name,
            inputs,
            attrs,
            num_outputs);
    }

    /// <summary>
    /// Converts <paramref name="args"/> into eager tensors that all share one dtype,
    /// mirroring TensorFlow's Python <c>args_to_matching_eager</c>.
    /// </summary>
    /// <param name="ctx">Eager context forwarded to <c>ops.convert_to_tensor</c>.</param>
    /// <param name="default_dtype">
    /// Dtype returned for an empty argument list, and the preferred dtype used when
    /// no argument is already an eager tensor.
    /// </param>
    /// <param name="args">Values to convert; <c>null</c> is treated as an empty list.</param>
    /// <returns>
    /// The common dtype and the converted tensors. For an empty/null
    /// <paramref name="args"/> the tensor array is <c>null</c> (kept for compatibility
    /// with existing callers).
    /// </returns>
    /// <exception cref="ArgumentException">
    /// <paramref name="args"/> is empty/null and <paramref name="default_dtype"/> is
    /// <see cref="TF_DataType.DtInvalid"/>, so no dtype can be inferred.
    /// </exception>
    public (TF_DataType, EagerTensor[]) args_to_matching_eager(Context ctx, TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null)
    {
        // Nothing to inspect: the dtype has to come from the caller.
        if (args == null || args.Length == 0)
        {
            if (default_dtype == TF_DataType.DtInvalid)
                throw new ArgumentException(
                    "Cannot infer a dtype from an empty argument list; pass a valid default_dtype.",
                    nameof(args));
            return (default_dtype, null);
        }

        // Fast path: every entry is already an eager tensor; the first one defines the dtype.
        if (args.All(x => x is EagerTensor))
            return ((args[0] as EagerTensor).dtype, args.Select(x => x as EagerTensor).ToArray());

        // Is some input already an eager tensor with a dtype? Use the first one found.
        var dtype = TF_DataType.DtInvalid;
        foreach (var x in args)
        {
            if (x is EagerTensor et)
            {
                dtype = et.dtype;
                break;
            }
        }

        if (dtype != TF_DataType.DtInvalid)
        {
            // An existing tensor fixes the dtype: convert every entry to it.
            var matched = args
                .Select(t => ops.convert_to_tensor(t, dtype, ctx: ctx) as EagerTensor)
                .ToArray();
            return (dtype, matched);
        }

        // No eager tensor among the inputs: infer the dtype from the first converted
        // value and use it for the remaining values.
        // NOTE(review): assumes convert_to_tensor yields an EagerTensor in eager mode;
        // any other result becomes null via `as`.
        var ret = new List<EagerTensor>(args.Length);
        foreach (var t in args)
        {
            ret.Add(ops.convert_to_tensor(t, dtype, preferred_dtype: default_dtype, ctx: ctx) as EagerTensor);
            if (dtype == TF_DataType.DtInvalid)
                dtype = ret[ret.Count - 1].dtype;
        }
        return (dtype, ret.ToArray());
    }
}
}