forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAutoGraph.cs
More file actions
71 lines (63 loc) · 2.24 KB
/
AutoGraph.cs
File metadata and controls
71 lines (63 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;
using static Tensorflow.Binding;
namespace Tensorflow.Graphs
{
    /// <summary>
    /// Converts managed tensor functions into TensorFlow graph functions.
    /// Each <c>to_graph</c> overload traces the supplied delegate once into a
    /// <see cref="FuncGraph"/> and returns a delegate that runs the traced
    /// function by name through <c>tf.Runner.TFE_Execute</c>.
    /// </summary>
    public class AutoGraph
    {
        /// <summary>Traces a one-input function into a graph function.</summary>
        /// <param name="func">Function to trace; it is invoked once, during this call, on a placeholder tensor.</param>
        /// <param name="dtype">Element type of the input placeholder. Defaults to int32, the type previously hard-coded.</param>
        /// <returns>A delegate that executes the traced function on its argument and returns the first (only) output.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is <c>null</c>.</exception>
        public Func<Tensor, Tensor> to_graph(Func<Tensor, Tensor> func, TF_DataType dtype = TF_DataType.TF_INT32)
        {
            if (func == null)
            {
                throw new ArgumentNullException(nameof(func));
            }

            string func_name = TraceFunction(func.Method.Name,
                new[] { dtype },
                inputs => func(inputs[0]));

            return (Tensor input) => ExecuteFunction(func_name, input);
        }

        /// <summary>Traces a two-input function into a graph function.</summary>
        /// <param name="func">Function to trace; it is invoked once, during this call, on two placeholder tensors.</param>
        /// <param name="dtype1">Element type of the first input placeholder. Defaults to int32, the type previously hard-coded.</param>
        /// <param name="dtype2">Element type of the second input placeholder. Defaults to int32, the type previously hard-coded.</param>
        /// <returns>A delegate that executes the traced function on its two arguments and returns the first (only) output.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="func"/> is <c>null</c>.</exception>
        public Func<Tensor, Tensor, Tensor> to_graph(Func<Tensor, Tensor, Tensor> func,
            TF_DataType dtype1 = TF_DataType.TF_INT32,
            TF_DataType dtype2 = TF_DataType.TF_INT32)
        {
            if (func == null)
            {
                throw new ArgumentNullException(nameof(func));
            }

            string func_name = TraceFunction(func.Method.Name,
                new[] { dtype1, dtype2 },
                inputs => func(inputs[0], inputs[1]));

            return (Tensor a, Tensor b) => ExecuteFunction(func_name, a, b);
        }

        /// <summary>
        /// Creates one placeholder per entry of <paramref name="inputTypes"/> inside a new
        /// <see cref="FuncGraph"/>, calls <paramref name="body"/> on them, and converts every
        /// operation recorded in that graph into a graph function via <c>FuncGraph.ToGraph</c>.
        /// </summary>
        /// <param name="methodName">Name of the traced managed method; embedded in the function name for readability.</param>
        /// <param name="inputTypes">Element types of the input placeholders, in argument order.</param>
        /// <param name="body">Builds the output tensor from the placeholders.</param>
        /// <returns>The generated function name, used later to invoke the function.</returns>
        private static string TraceFunction(string methodName, TF_DataType[] inputTypes, Func<Tensor[], Tensor> body)
        {
            // The GUID keeps names unique when the same method is traced more than once.
            string func_name = $"autograph_{Guid.NewGuid()}_{methodName}";

            using (var graph = new FuncGraph(func_name))
            {
                // NOTE(review): this relies on FuncGraph acting as the default graph while the
                // using-block is open, so the placeholders and the ops built by `body` land in
                // `graph._nodes_by_name` — confirm against FuncGraph's constructor.
                var inputs = inputTypes.Select(dtype => tf.placeholder(dtype)).ToArray();
                var output = body(inputs);

                // OfType skips anything that is not an Operation instead of passing nulls on to ToGraph.
                var opers = graph._nodes_by_name.Values.OfType<Operation>().ToArray();
                graph.ToGraph(opers,
                    inputs.Select(t => (Operation)t).ToArray(),
                    new Operation[] { output },
                    null);
            }

            return func_name;
        }

        /// <summary>
        /// Runs the previously traced function <paramref name="func_name"/> on
        /// <paramref name="inputs"/> in the current context and device, requesting one output.
        /// </summary>
        /// <returns>The first output tensor of the function.</returns>
        private static Tensor ExecuteFunction(string func_name, params Tensor[] inputs)
        {
            var result = tf.Runner.TFE_Execute(tf.Context,
                tf.Context.DeviceName,
                func_name,
                inputs,
                null,
                1);
            return result[0];
        }
    }
}