forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFuncGraph.cs
More file actions
69 lines (60 loc) · 2.03 KB
/
FuncGraph.cs
File metadata and controls
69 lines (60 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
using System;
using System.Collections.Generic;
using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow.Graphs
{
    /// <summary>
    /// Graph representing a function body.
    /// </summary>
    /// <remarks>
    /// The constructor records the current default graph as the outer graph, then calls
    /// <c>tf.Context.graph_mode()</c> and <c>as_default()</c>. <see cref="DisposeManagedResources"/>
    /// calls <c>tf.Context.restore_mode()</c>. Once the body has been built, call
    /// <see cref="ToGraph"/> to convert it into a TensorFlow function. That function is copied
    /// into the outer graph and added to the eager context.
    /// </remarks>
    public class FuncGraph : Graph
    {
        // Default graph at the time this FuncGraph was constructed; ToGraph copies the
        // finished function into it.
        readonly Graph outer_graph;

        // Name requested at construction; ToGraph overwrites it with the name TensorFlow
        // reports for the created function.
        string func_name;

        // Native function handle produced by the most recent ToGraph call (IntPtr.Zero before).
        // NOTE(review): nothing in this class releases this handle, and a second ToGraph call
        // overwrites it. Callers also receive it from ToGraph, so confirm who owns/frees it.
        IntPtr func_handle;

        /// <summary>
        /// Function name: the constructor argument before <see cref="ToGraph"/> runs, and the
        /// name reported by TensorFlow for the created function afterwards.
        /// </summary>
        public string FuncName => func_name;

        /// <summary>
        /// Construct a new FuncGraph.
        /// </summary>
        /// <param name="name">Requested name of the function.</param>
        public FuncGraph(string name) : base()
        {
            // Capture the outer graph before this graph is made the default below.
            outer_graph = ops.get_default_graph();
            func_name = name;
            tf.Context.graph_mode();
            as_default();
        }

        /// <summary>
        /// Converts this graph into a TensorFlow function. The function is copied into the
        /// outer graph and registered with the eager context.
        /// </summary>
        /// <param name="opers">
        /// Operations forming the function body. Pass <c>null</c> to let TensorFlow use every
        /// operation of this graph except those listed in <paramref name="inputs"/>
        /// (the C API's <c>num_opers = -1</c> form).
        /// </param>
        /// <param name="inputs">Operations whose output 0 becomes a function input.</param>
        /// <param name="outputs">Operations whose output 0 becomes a function output.</param>
        /// <param name="output_names">
        /// Names for the function outputs. <c>null</c> or empty lets TensorFlow choose them.
        /// </param>
        /// <returns>The native function handle (also stored in this instance).</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="inputs"/> or <paramref name="outputs"/> is <c>null</c>.
        /// </exception>
        public IntPtr ToGraph(Operation[] opers,
            Operation[] inputs, Operation[] outputs,
            string[] output_names)
        {
            if (inputs == null)
            {
                throw new ArgumentNullException(nameof(inputs));
            }
            if (outputs == null)
            {
                throw new ArgumentNullException(nameof(outputs));
            }

            // The C API distinguishes "no array" (-1, NULL: whole graph minus inputs) from an
            // empty array (function with no body), so null is not the same as empty here.
            int num_opers = opers == null ? -1 : opers.Length;
            IntPtr[] oper_handles = opers?.Select(x => (IntPtr)x).ToArray();

            using var status = new Status();
            func_handle = c_api.TF_GraphToFunction(_handle,
                func_name,
                false, // do not append a hash to the function name
                num_opers,
                oper_handles,
                inputs.Length,
                inputs.Select(x => new TF_Output(x, 0)).ToArray(),
                outputs.Length,
                outputs.Select(x => new TF_Output(x, 0)).ToArray(),
                output_names == null || output_names.Length == 0 ? null : output_names,
                IntPtr.Zero,
                null,
                status.Handle);
            // Presumably throws when the native call reported a non-OK status.
            status.Check(true);

            // Make the function callable from the enclosing graph...
            c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
            status.Check(true);

            // ...and from the eager context.
            c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle);
            status.Check(true);

            // Read back the name TensorFlow actually assigned to the function.
            func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle));
            return func_handle;
        }

        /// <summary>
        /// Disposes base-graph managed resources, then calls <c>tf.Context.restore_mode()</c>.
        /// The constructor switched modes with <c>tf.Context.graph_mode()</c>.
        /// </summary>
        protected override void DisposeManagedResources()
        {
            base.DisposeManagedResources();
            tf.Context.restore_mode();
        }
    }
}