forked from SciSharp/TensorFlow.NET
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBaseSession.cs
More file actions
139 lines (113 loc) · 4.31 KB
/
BaseSession.cs
File metadata and controls
139 lines (113 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Tensorflow
{
public class BaseSession : IDisposable
{
    private readonly Graph _graph;
    // Not read anywhere in this class yet; presumably reserved for session lifecycle / graph-version tracking.
    private bool _opened;
    private bool _closed;
    private int _current_version;
    // UTF-8 bytes of the `target` constructor argument; stored but not forwarded to TensorFlow here.
    private readonly byte[] _target;
    // Native session handle returned by c_api.TF_NewSession.
    private IntPtr _session;

    /// <summary>
    /// Creates a native TensorFlow session bound to <paramref name="graph"/>.
    /// </summary>
    /// <param name="target">
    /// Execution target. It is only stored (UTF-8 encoded) and not yet passed to TensorFlow.
    /// <c>null</c> is treated as an empty string.
    /// </param>
    /// <param name="graph">Graph to run; when <c>null</c>, <c>ops.get_default_graph()</c> is used.</param>
    /// <exception cref="InvalidOperationException">TF_NewSession returned a null session handle.</exception>
    public BaseSession(string target = "", Graph graph = null)
    {
        _graph = graph ?? ops.get_default_graph();
        _target = Encoding.UTF8.GetBytes(target ?? string.Empty);

        var opts = c_api.TF_NewSessionOptions();
        try
        {
            var status = new Status();
            _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle);
        }
        finally
        {
            // Release the options whether or not session creation succeeded.
            c_api.TF_DeleteSessionOptions(opts);
        }

        // The TF C API returns NULL when session creation fails. Fail fast here instead of
        // handing a null session to TF_SessionRun later.
        // NOTE(review): the status message is not surfaced; TODO include it once Status exposes it.
        if (_session == IntPtr.Zero)
        {
            throw new InvalidOperationException("TF_NewSession failed to create a session.");
        }
    }

    /// <summary>
    /// Currently a no-op: the native session held in <c>_session</c> is not closed or freed.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this leaks the native session. It should presumably close and delete it
    /// (TF_CloseSession / TF_DeleteSession in the TF C API) once those bindings exist in c_api.
    /// </remarks>
    public void Dispose()
    {
    }

    /// <summary>
    /// Evaluates <paramref name="fetches"/> in this session.
    /// </summary>
    /// <param name="fetches">Tensor to evaluate.</param>
    /// <param name="feed_dict">
    /// Values keyed by tensor. <c>float</c> and <c>NDArray</c> values are converted, but the
    /// converted feeds are not yet sent to TensorFlow.
    /// </param>
    /// <returns>The value produced by the fetch handler's <c>build_results</c>.</returns>
    /// <exception cref="InvalidOperationException">The native run produced no value for a fetch.</exception>
    public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
    {
        return _run(fetches, feed_dict);
    }

    /// <summary>
    /// Converts feeds, resolves fetches through a <c>_FetchHandler</c>, runs them and builds the result.
    /// </summary>
    private object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
    {
        var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
        if (feed_dict != null)
        {
            foreach (var feed in feed_dict)
            {
                // Declared per feed so a value of an unsupported type maps to null instead of
                // silently inheriting the previous feed's array.
                NDArray np_val = null;
                switch (feed.Value)
                {
                    case float value:
                        np_val = np.asarray(value);
                        break;
                    case NDArray array:
                        np_val = array;
                        break;
                }
                feed_dict_tensor[feed.Key] = np_val;
            }
        }
        // NOTE(review): feed_dict_tensor is never passed to _do_run, so feeds currently have no effect.

        // Create a fetch handler to take care of the structure of fetches.
        var fetch_handler = new _FetchHandler(_graph, fetches);

        // The movers must stay alive until _do_run() completes.
        var movers = _update_with_movers();
        var final_fetches = fetch_handler.fetches();
        // Not yet forwarded to TF_SessionRun (which is always called with ntargets: 0).
        var final_targets = fetch_handler.targets();

        var results = _do_run(final_fetches);
        // The JIT may treat a local as dead before its scope ends, so pin the movers through the run.
        GC.KeepAlive(movers);
        return fetch_handler.build_results(null, results);
    }

    /// <summary>
    /// Converts the fetch list to native outputs and executes the run.
    /// </summary>
    /// <exception cref="NotSupportedException">An element of <paramref name="fetch_list"/> is not a <c>Tensor</c>.</exception>
    private object[] _do_run(List<object> fetch_list)
    {
        var fetches = new TF_Output[fetch_list.Count];
        for (int i = 0; i < fetch_list.Count; i++)
        {
            if (!(fetch_list[i] is Tensor tensor))
            {
                throw new NotSupportedException(
                    $"Only Tensor fetches are supported, but fetch {i} is {fetch_list[i]?.GetType().FullName ?? "null"}.");
            }
            fetches[i] = tensor._as_tf_output();
        }
        return _call_tf_sessionrun(fetches);
    }

    /// <summary>
    /// Calls TF_SessionRun with no feeds and no targets, then reads one <c>float</c> from each output.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// A fetch produced no tensor (TF_SessionRun leaves outputs NULL on failure) or a tensor has no data.
    /// </exception>
    private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list)
    {
        // Ensure any changes to the graph are reflected in the runtime.
        _extend_graph();

        var status = new Status();
        // Zero-initialised, i.e. every slot starts as IntPtr.Zero.
        var output_values = new IntPtr[fetch_list.Length];

        c_api.TF_SessionRun(_session,
            run_options: IntPtr.Zero,
            inputs: new TF_Output[] { },
            input_values: new IntPtr[] { },
            ninputs: 0,
            outputs: fetch_list,
            output_values: output_values,
            noutputs: fetch_list.Length,
            target_opers: new IntPtr[] { },
            ntargets: 0,
            run_metadata: IntPtr.Zero,
            status: status.Handle);

        var result = new object[output_values.Length];
        for (int i = 0; i < output_values.Length; i++)
        {
            // On failure the outputs remain NULL; dereferencing them would crash the process.
            if (output_values[i] == IntPtr.Zero)
            {
                throw new InvalidOperationException(
                    $"TF_SessionRun produced no value for fetch {i}; the session run most likely failed.");
            }

            var data = (float*)c_api.TF_TensorData(output_values[i]);
            if (data == null)
            {
                throw new InvalidOperationException($"Fetched tensor {i} has no data buffer.");
            }

            // NOTE(review): assumes each fetched tensor is a float32 scalar (reads exactly one float);
            // TODO dispatch on the tensor dtype/shape. The native output tensors are also never released here.
            result[i] = *data;
        }
        return result;
    }

    /// <summary>
    /// If a tensor handle that is fed to a device incompatible placeholder,
    /// we move the tensor to the right device, generate a new tensor handle,
    /// and update feed_dict to use the new handle.
    /// </summary>
    /// <remarks>Currently a stub that always returns an empty list.</remarks>
    private List<object> _update_with_movers()
    {
        return new List<object>();
    }

    /// <summary>
    /// Intended to push pending graph changes to the runtime; currently a no-op.
    /// </summary>
    private void _extend_graph()
    {
    }
}
}