Skip to content

Commit 2995c36

Browse files
committed
sess.run not finished yet
1 parent bc13c9a commit 2995c36

11 files changed

Lines changed: 117 additions & 36 deletions

File tree

src/TensorFlowNET.Core/BaseSession.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Tensorflow
66
{
7-
public class BaseSession
7+
public class BaseSession : IDisposable
88
{
99
private Graph _graph;
1010
private bool _opened;
@@ -32,18 +32,23 @@ public BaseSession(string target = "", Graph graph = null)
3232
c_api.TF_DeleteSessionOptions(opts);
3333
}
3434

35-
public virtual byte[] run(Tensor fetches)
35+
public void Dispose()
3636
{
37-
return _run(fetches);
37+
3838
}
3939

40-
private unsafe byte[] _run(Tensor fetches)
40+
public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
41+
{
42+
return _run(fetches, feed_dict);
43+
}
44+
45+
private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
4146
{
4247
var status = new Status();
4348

4449
c_api.TF_SessionRun(_session,
4550
run_options: null,
46-
inputs: new TF_Input[] { },
51+
inputs: new TF_Output[] { },
4752
input_values: new IntPtr[] { },
4853
ninputs: 1,
4954
outputs: new TF_Output[] { },

src/TensorFlowNET.Core/Graph.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public Graph(IntPtr graph)
3131
_names_in_use = new Dictionary<string, int>();
3232
}
3333

34-
public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] dtypes,
34+
public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
3535
TF_DataType[] input_types = null, string name = "",
3636
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
3737
{
@@ -43,9 +43,13 @@ public unsafe Operation create_op(string op_type, object inputs, TF_DataType[] d
4343
name = name.EndsWith("/") ? ops._name_from_scope_name(name) : unique_name(name);
4444
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
4545

46-
var op = new Operation(node_def, this,
46+
var op = new Operation(node_def,
47+
this,
4748
inputs: inputs,
4849
output_types: dtypes,
50+
control_inputs: new object[] { },
51+
input_types: input_types,
52+
original_op: null,
4953
op_def: op_def);
5054

5155
return op;
@@ -73,6 +77,7 @@ public string unique_name(string name)
7377
else
7478
{
7579
_names_in_use[name_key] = 1;
80+
return name;
7681
}
7782

7883

src/TensorFlowNET.Core/OpDefLibrary.cs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,11 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
4747
}
4848

4949
var attrs = new Dictionary<string, object>();
50+
51+
// Perform input type inference
5052
var inputs = new List<Tensor>();
5153
var input_types = new List<DataType>();
52-
53-
foreach (var attr in op_def.Attr)
54-
{
55-
if (keywords.ContainsKey(attr.Name))
56-
{
57-
attrs[attr.Name] = keywords[attr.Name];
58-
}
59-
}
60-
54+
6155
foreach (var input_arg in op_def.InputArg)
6256
{
6357
var input_name = input_arg.Name;
@@ -70,18 +64,38 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
7064
{
7165
attrs[input_arg.TypeAttr] = DataType.DtFloat;
7266
}
67+
68+
if (input_arg.IsRef)
69+
{
70+
71+
}
72+
else
73+
{
74+
input_types.Add((keywords[input_name] as Tensor).dtype);
75+
}
7376
}
7477

78+
// Process remaining attrs
79+
foreach (var attr in op_def.Attr)
80+
{
81+
if (keywords.ContainsKey(attr.Name))
82+
{
83+
attrs[attr.Name] = keywords[attr.Name];
84+
}
85+
}
86+
87+
// Convert attr values to AttrValue protos.
7588
var attr_protos = new Dictionary<string, AttrValue>();
7689
foreach (var attr_def in op_def.Attr)
7790
{
7891
var key = attr_def.Name;
92+
var value = attrs[key];
7993
var attr_value = new AttrValue();
8094

8195
switch (attr_def.Type)
8296
{
8397
case "type":
84-
attr_value.Type = (DataType)keywords["dtype"];
98+
attr_value.Type = _MakeType(value, attr_def);
8599
break;
86100
case "shape":
87101
attr_value.Shape = new TensorShapeProto();
@@ -91,6 +105,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
91105
attr_protos[key] = attr_value;
92106
}
93107

108+
// Determine output types (possibly using attrs)
94109
var output_types = new List<DataType>();
95110

96111
foreach (var arg in op_def.OutputArg)
@@ -105,6 +120,7 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
105120
}
106121
}
107122

123+
// Add Op to graph
108124
var op = g.create_op(op_type_name, inputs, output_types.ToArray(),
109125
name: scope,
110126
input_types: input_types.ToArray(),
@@ -113,5 +129,10 @@ public unsafe Operation _apply_op_helper(string op_type_name, string name = "",
113129

114130
return op;
115131
}
132+
133+
public DataType _MakeType(Object v, AttrDef attr_def)
134+
{
135+
return DataType.DtFloat;
136+
}
116137
}
117138
}

src/TensorFlowNET.Core/Operation.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow
88
public class Operation
99
{
1010
private Graph _graph;
11-
private IntPtr _c_op;
11+
public IntPtr _c_op;
1212
public int _id => _id_value;
1313
private int _id_value;
1414
public string name;
@@ -27,7 +27,7 @@ public Operation(Graph g, string opType, string oper_name)
2727
c_api.TF_FinishOperation(desc, status.Handle);
2828
}
2929

30-
public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
30+
public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null)
3131
{
3232
_graph = g;
3333

@@ -38,7 +38,7 @@ public Operation(NodeDef node_def, Graph g, object inputs = null, TF_DataType[]
3838
_outputs = new Tensor[num_outputs];
3939
for (int i = 0; i < num_outputs; i++)
4040
{
41-
_outputs[i] = new Tensor(this, i, TF_DataType.DtDouble);
41+
_outputs[i] = new Tensor(this, i, TF_DataType.DtFloat);
4242
}
4343

4444
_graph._add_op(this);

src/TensorFlowNET.Core/Session.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,5 @@ namespace Tensorflow
66
{
77
public class Session : BaseSession
88
{
9-
public override byte[] run(Tensor fetches)
10-
{
11-
var ret = base.run(fetches);
12-
13-
return ret;
14-
}
159
}
1610
}

src/TensorFlowNET.Core/Tensor.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,23 @@ namespace Tensorflow
66
{
77
public class Tensor
88
{
9-
private Operation _op;
10-
private int _value_index;
9+
private readonly Operation _op;
10+
public Operation op => _op;
11+
private readonly int _value_index;
12+
public int value_index => _value_index;
1113
private DataType _dtype;
14+
public DataType dtype => _dtype;
1215

1316
public Tensor(Operation op, int value_index, DataType dtype)
1417
{
1518
_op = op;
1619
_value_index = value_index;
1720
_dtype = dtype;
1821
}
22+
23+
public TF_Output _as_tf_output()
24+
{
25+
return c_api_util.tf_output(_op._c_op, _value_index);
26+
}
1927
}
2028
}

src/TensorFlowNET.Core/c_api.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ public static class c_api
1919
{
2020
public const string TensorFlowLibName = "tensorflow";
2121

22+
/// <summary>
23+
/// For inputs that take a single tensor.
24+
/// </summary>
25+
/// <param name="desc"></param>
26+
/// <param name="input"></param>
27+
[DllImport(TensorFlowLibName)]
28+
public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input);
29+
2230
[DllImport(TensorFlowLibName)]
2331
public static unsafe extern void TF_DeleteSessionOptions(TF_SessionOptions opts);
2432

@@ -60,11 +68,11 @@ public static class c_api
6068

6169
[DllImport(TensorFlowLibName)]
6270
public static extern unsafe void TF_SessionRun(TF_Session session, TF_Buffer* run_options,
63-
TF_Input[] inputs, TF_Tensor[] input_values,
64-
int ninputs, TF_Output[] outputs,
65-
TF_Tensor[] output_values, int noutputs,
71+
TF_Output[] inputs, TF_Tensor[] input_values, int ninputs,
72+
TF_Output[] outputs, TF_Tensor[] output_values, int noutputs,
6673
TF_Operation[] target_opers, int ntargets,
67-
TF_Buffer* run_metadata, TF_Status status);
74+
TF_Buffer* run_metadata,
75+
TF_Status status);
6876

6977
[DllImport(TensorFlowLibName)]
7078
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public class c_api_util
8+
{
9+
public static TF_Output tf_output(IntPtr c_op, int index)
10+
{
11+
var ret = new TF_Output();
12+
ret.oper = c_op;
13+
ret.index = index;
14+
15+
return ret;
16+
}
17+
}
18+
}

src/TensorFlowNET.Core/ops.cs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,35 @@ public static Graph get_default_graph()
1616
return tf.Graph();
1717
}
1818

19-
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, object inputs)
19+
public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List<Tensor> inputs)
2020
{
2121
var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name);
22+
23+
// Add inputs
24+
foreach(var op_input in inputs)
25+
{
26+
c_api.TF_AddInput(op_desc, op_input._as_tf_output());
27+
}
28+
2229
var status = new Status();
2330

31+
// Add control inputs
32+
33+
// Add attrs
2434
foreach (var attr in node_def.Attr)
2535
{
2636
var bytes = attr.Value.ToByteArray();
2737
var proto = Marshal.AllocHGlobal(bytes.Length);
2838
Marshal.Copy(bytes, 0, proto, bytes.Length);
2939
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle);
40+
41+
if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
3042
}
3143

3244
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle);
3345

46+
if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message);
47+
3448
return c_op;
3549
}
3650

src/TensorFlowNET.Core/ops/gen_math_ops.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ public static Tensor add(Tensor a, Tensor b)
1717

1818
var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords);
1919

20-
return null;
20+
var tensor = new Tensor(_op, 0, DataType.DtFloat);
21+
22+
return tensor;
2123
}
2224

2325
private static OpDefLibrary _InitOpDefLibrary()

0 commit comments

Comments
 (0)