Skip to content

Commit 8adb34d

Browse files
committed
init session.run
1 parent d392e9d commit 8adb34d

6 files changed

Lines changed: 65 additions & 4 deletions

File tree

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace TensorFlowNET.Core
6+
{
7+
public class BaseSession
8+
{
9+
private Graph _graph;
10+
private bool _opened;
11+
private bool _closed;
12+
private int _current_version;
13+
private byte[] _target;
14+
private IntPtr _session;
15+
16+
public BaseSession(string target = "", Graph graph = null)
17+
{
18+
if(graph is null)
19+
{
20+
_graph = ops.get_default_graph();
21+
}
22+
else
23+
{
24+
_graph = graph;
25+
}
26+
27+
_target = UTF8Encoding.UTF8.GetBytes(target);
28+
var opts = c_api.TF_NewSessionOptions();
29+
var status = new Status();
30+
_session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle);
31+
32+
c_api.TF_DeleteSessionOptions(opts);
33+
}
34+
35+
public virtual byte[] run(Tensor fetches)
36+
{
37+
return null;
38+
}
39+
}
40+
}

src/TensorFlowNET.Core/Graph.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ namespace TensorFlowNET.Core
1616
/// </summary>
1717
public class Graph
1818
{
19-
public IntPtr handle;
19+
private IntPtr _c_graph;
20+
public IntPtr Handle => _c_graph;
2021
private Dictionary<int, Operation> _nodes_by_id;
2122
private Dictionary<string, Operation> _nodes_by_name;
2223
private Dictionary<string, int> _names_in_use;
@@ -25,7 +26,7 @@ public class Graph
2526

2627
public Graph(IntPtr graph)
2728
{
28-
this.handle = graph;
29+
this._c_graph = graph;
2930
_nodes_by_id = new Dictionary<int, Operation>();
3031
_nodes_by_name = new Dictionary<string, Operation>();
3132
_names_in_use = new Dictionary<string, int>();

src/TensorFlowNET.Core/Session.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44

55
namespace TensorFlowNET.Core
66
{
7-
public class Session
7+
public class Session : BaseSession
88
{
9+
public override byte[] run(Tensor fetches)
10+
{
11+
var ret = base.run(fetches);
12+
13+
return ret;
14+
}
915
}
1016
}

src/TensorFlowNET.Core/c_api.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
using TF_Operation = System.IntPtr;
1010
using TF_Status = System.IntPtr;
1111
using TF_Tensor = System.IntPtr;
12+
using TF_Session = System.IntPtr;
13+
using TF_SessionOptions = System.IntPtr;
1214

1315
using TF_DataType = Tensorflow.DataType;
1416
using Tensorflow;
@@ -20,6 +22,9 @@ public static class c_api
2022
{
2123
public const string TensorFlowLibName = "tensorflow";
2224

25+
[DllImport(TensorFlowLibName)]
26+
public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts);
27+
2328
[DllImport(TensorFlowLibName)]
2429
public static unsafe extern TF_Operation TF_FinishOperation(TF_OperationDescription desc, TF_Status status);
2530

@@ -53,6 +58,12 @@ public static class c_api
5358
[DllImport(TensorFlowLibName)]
5459
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);
5560

61+
[DllImport(TensorFlowLibName)]
62+
public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status);
63+
64+
[DllImport(TensorFlowLibName)]
65+
public static extern TF_SessionOptions TF_NewSessionOptions();
66+
5667
[DllImport(TensorFlowLibName)]
5768
public static unsafe extern IntPtr TF_Version();
5869
}

src/TensorFlowNET.Core/ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static Graph get_default_graph()
2020

2121
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
2222
{
23-
var op_desc = c_api.TF_NewOperation(graph.handle, node_def.Op, node_def.Name);
23+
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);
2424
var status = new Status();
2525

2626
foreach (var attr in node_def.Attr)

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ of the Constant op.*/
2222

2323
// Start tf session
2424
var sess = tf.Session();
25+
26+
// Run the op
27+
sess.run(hello);
2528
}
2629
}
2730
}

0 commit comments

Comments
 (0)