Skip to content

Commit ce59939

Browse files
committed
fix default session
1 parent 9bc0c86 commit ce59939

6 files changed

Lines changed: 36 additions & 14 deletions

File tree

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
4343

4444
private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
4545
{
46-
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();
46+
var feed_dict_tensor = new Dictionary<object, object>();
4747

4848
if (feed_dict != null)
4949
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));
@@ -79,9 +79,30 @@ private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
7979
/// name of an operation, the first Tensor output of that operation
8080
/// will be returned for that element.
8181
/// </returns>
82-
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<Tensor, NDArray> feed_dict)
82+
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
8383
{
84-
var feeds = feed_dict.Select(x => new KeyValuePair<TF_Output, Tensor>(x.Key._as_tf_output(), new Tensor(x.Value))).ToArray();
84+
var feeds = feed_dict.Select(x =>
85+
{
86+
if(x.Key is Tensor tensor)
87+
{
88+
switch (x.Value)
89+
{
90+
case Tensor t1:
91+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), t1);
92+
case NDArray nd:
93+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(nd));
94+
case int intVal:
95+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(intVal));
96+
case float floatVal:
97+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(floatVal));
98+
case double doubleVal:
99+
return new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(doubleVal));
100+
default:
101+
break;
102+
}
103+
}
104+
throw new NotImplementedException("_do_run.feed_dict");
105+
}).ToArray();
85106
var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
86107
var targets = target_list;
87108

src/TensorFlowNET.Core/Sessions/FeedItem.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ namespace Tensorflow
1010
/// </summary>
1111
public class FeedItem
1212
{
13-
public Tensor Key { get; }
14-
public NDArray Value { get; }
13+
public object Key { get; }
14+
public object Value { get; }
1515

16-
public FeedItem(Tensor tensor, NDArray nd)
16+
public FeedItem(object key, object val)
1717
{
18-
Key = tensor;
19-
Value = nd;
18+
Key = key;
19+
Value = val;
2020
}
2121
}
2222
}

src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ public class _FetchHandler<T>
1616
private List<Tensor> _final_fetches = new List<Tensor>();
1717
private List<T> _targets = new List<T>();
1818

19-
public _FetchHandler(Graph graph, T fetches, Dictionary<Tensor, NDArray> feeds = null, Action feed_handles = null)
19+
public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
2020
{
2121
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
2222
foreach(var fetch in _fetch_mapper.unique_fetches())

src/TensorFlowNET.Core/Train/Saving/Saver.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ public string save(Session sess,
160160

161161
if (!_is_empty)
162162
{
163-
/*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
163+
var model_checkpoint_path1 = sess.run(_saver_def.SaveTensorName, new FeedItem[] {
164164
new FeedItem(_saver_def.FilenameTensorName, checkpoint_file)
165-
});*/
165+
});
166166
}
167167

168168
throw new NotImplementedException("");

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed
289289
/// <returns>The default `Session` being used in the current thread.</returns>
290290
public static Session get_default_session()
291291
{
292-
return tf.Session();
292+
return tf.defaultSession;
293293
}
294294

295295
public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session)

src/TensorFlowNET.Core/tf.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static partial class tf
1717
public static Context context = new Context(new ContextOptions(), new Status());
1818

1919
public static Graph g = new Graph();
20-
public static Session session = new Session();
20+
public static Session defaultSession;
2121

2222
public static RefVariable Variable<T>(T data, string name = "", TF_DataType dtype = TF_DataType.DtInvalid)
2323
{
@@ -49,7 +49,8 @@ public static Graph Graph()
4949

5050
public static Session Session()
5151
{
52-
return session;
52+
defaultSession = new Session();
53+
return defaultSession;
5354
}
5455
}
5556
}

0 commit comments

Comments
 (0)