Skip to content

Commit 87f9901

Browse files
committed
Fix trainable_weights is empty when using Keras Functional model SciSharp#626
1 parent b6f155c commit 87f9901

16 files changed

Lines changed: 205 additions & 126 deletions

File tree

src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,53 +13,81 @@ namespace Tensorflow.Graphs
1313
public sealed class AutoGraphAttribute : OnMethodBoundaryAspect
1414
{
1515
FuncGraph graph;
16-
Tensor[] originalInputs;
16+
Tensors originalInputs;
1717
string func_name;
18-
static Dictionary<string, Func<Tensor[], Tensor>> functions = new Dictionary<string, Func<Tensor[], Tensor>>();
18+
static Dictionary<string, Func<Tensors, Tensors>> functions = new Dictionary<string, Func<Tensors, Tensors>>();
1919

2020
public override void OnEntry(MethodExecutionArgs args)
2121
{
22-
if (args.Instance is TensorFlowOpLayer op)
23-
func_name = $"autograph_{op.OpType}.{args.Method.Name}";
24-
else
25-
func_name = $"autograph_{args.Instance}.{args.Method.Name}";
22+
func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}";
2623

2724
if (functions.ContainsKey(func_name))
2825
{
29-
args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray());
26+
if(args.Arguments[0] is Tensors tensor_inputs)
27+
args.ReturnValue = functions[func_name](tensor_inputs.ToArray());
28+
else
29+
args.ReturnValue = functions[func_name](args.Arguments.Select(x => x as Tensor).ToArray());
3030
args.FlowBehavior = FlowBehavior.Return;
3131
return;
3232
}
3333

3434
// make function as an Operation by autograph
3535
graph = new FuncGraph(func_name);
3636

37-
originalInputs = new Tensor[args.Arguments.Length];
38-
// convert args to placeholder
39-
for (var i = 0; i < args.Arguments.Length; i++)
37+
// convert to Tensors
38+
if(args.Arguments[0] is Tensors inputs)
39+
{
40+
originalInputs = inputs;
41+
var new_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.TensorShape)).ToArray();
42+
args.Arguments[0] = new Tensors(new_inputs);
43+
}
44+
else
4045
{
41-
if (args.Arguments[i] is EagerTensor tensor)
46+
originalInputs = new Tensors(args.Arguments.Length);
47+
// convert args to placeholder
48+
for (var i = 0; i < args.Arguments.Length; i++)
4249
{
43-
originalInputs[i] = tensor;
44-
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
50+
if (args.Arguments[i] is EagerTensor tensor)
51+
{
52+
originalInputs[i] = tensor;
53+
args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape);
54+
}
4555
}
4656
}
4757
}
4858

4959
public override void OnExit(MethodExecutionArgs args)
5060
{
51-
var output = (Tensor)args.ReturnValue;
52-
var inputs = args.Arguments.Select(x => x as Tensor).ToArray();
5361
var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray();
5462

55-
graph.ToGraph(opers,
56-
inputs.Select(x => x.op).ToArray(),
57-
new Operation[] { output.op },
58-
null);
63+
if (args.ReturnValue is Tensors outputs)
64+
{
65+
if(args.Arguments[0] is Tensors inputs)
66+
{
67+
graph.ToGraph(opers,
68+
inputs.Select(x => x.op).ToArray(),
69+
outputs.Select(x => x.op).ToArray(),
70+
null);
71+
}
72+
else
73+
{
74+
graph.ToGraph(opers,
75+
args.Arguments.Select(x => (x as Tensor).op).ToArray(),
76+
outputs.Select(x => x.op).ToArray(),
77+
null);
78+
}
79+
}
80+
else
81+
{
82+
graph.ToGraph(opers,
83+
args.Arguments.Select(x => (x as Tensor).op).ToArray(),
84+
new Operation[] { (args.ReturnValue as Tensor).op },
85+
null);
86+
}
5987

6088
graph.Dispose();
6189

62-
Func<Tensor[], Tensor> function = (x) =>
90+
Func<Tensors, Tensors> function = (x) =>
6391
{
6492
var result = tf.Runner.TFE_Execute(tf.Context,
6593
tf.Context.DeviceName,

src/TensorFlowNET.Core/Keras/Engine/Functional.cs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,25 @@ public class Functional : Model
2424
List<KerasHistory> _output_coordinates;
2525
public string[] NetworkNodes { get; set; }
2626
public Dictionary<int, List<Node>> NodesByDepth { get; set; }
27-
public List<Layer> Layers { get; set; }
27+
public List<Layer> Layers => _layers;
28+
2829
Dictionary<int, int> tensor_usage_count;
2930
public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
3031

32+
public override List<IVariableV1> trainable_variables
33+
{
34+
get
35+
{
36+
var variables = new List<IVariableV1>();
37+
foreach(var layer in _layers)
38+
{
39+
if (layer.Trainable)
40+
variables.AddRange(layer.trainable_variables);
41+
}
42+
return variables;
43+
}
44+
}
45+
3146
public Functional(Tensors inputs, Tensors outputs)
3247
: base(new ModelArgs
3348
{
@@ -80,7 +95,7 @@ void _init_graph_network(Tensors inputs, Tensors outputs)
8095

8196
NetworkNodes = nodes;
8297
NodesByDepth = nodes_by_depth;
83-
Layers = layers;
98+
_layers = layers;
8499

85100
ComputeTensorUsageCount();
86101
}
@@ -316,11 +331,15 @@ Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask =
316331
}
317332
}
318333

319-
foreach(var x in outputs)
320-
{
334+
var output_tensors = new List<Tensor>();
321335

336+
foreach (var x in outputs)
337+
{
338+
var x_id = x.GetHashCode();
339+
output_tensors.append(tensor_dict[x_id].Dequeue());
322340
}
323-
throw new NotImplementedException("");
341+
342+
return output_tensors;
324343
}
325344

326345
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)

src/TensorFlowNET.Core/Keras/Engine/Layer.AddWeights.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,9 @@ protected virtual IVariableV1 add_weight(string name,
5555

5656
//backend.track_variable(variable);
5757
if (trainable == true)
58-
trainableWeights.Add(variable);
58+
trainable_weights.Add(variable);
5959
else
60-
nonTrainableWeights.Add(variable);
60+
non_trainable_weights.Add(variable);
6161

6262
return variable;
6363
}

src/TensorFlowNET.Core/Keras/Engine/Layer.cs

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -58,20 +58,13 @@ public abstract partial class Layer : AutoTrackable
5858
protected InputSpec inputSpec;
5959
bool dynamic = true;
6060
public bool SupportsMasking { get; set; }
61-
protected List<IVariableV1> trainableWeights;
62-
public List<IVariableV1> trainable_variables
63-
{
64-
get
65-
{
66-
if(trainableWeights.Count == 0)
67-
_layers.ForEach(x => trainableWeights.AddRange(x.trainableWeights));
61+
protected List<IVariableV1> trainable_weights;
6862

69-
return trainableWeights;
70-
}
71-
}
63+
public virtual List<IVariableV1> trainable_variables => trainable_weights;
64+
7265

73-
protected List<IVariableV1> nonTrainableWeights;
74-
public List<IVariableV1> non_trainable_variables => nonTrainableWeights;
66+
protected List<IVariableV1> non_trainable_weights;
67+
public List<IVariableV1> non_trainable_variables => non_trainable_weights;
7568

7669
protected string name;
7770
protected string base_name;
@@ -103,8 +96,8 @@ public Layer(LayerArgs args)
10396
SupportsMasking = false;
10497

10598
_init_set_name(args.Name);
106-
trainableWeights = new List<IVariableV1>();
107-
nonTrainableWeights = new List<IVariableV1>();
99+
trainable_weights = new List<IVariableV1>();
100+
non_trainable_weights = new List<IVariableV1>();
108101
computePreviousMask = false;
109102
updates = new List<Operation>();
110103

src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using NumSharp;
22
using System;
33
using System.Collections.Generic;
4+
using System.Linq;
45
using System.Text;
56
using Tensorflow.Graphs;
67
using Tensorflow.Keras.ArgsDefinition;
@@ -37,18 +38,21 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_tra
3738
}
3839

3940
[AutoGraph]
40-
Tensor _defun_call(Tensor inputs)
41+
Tensors _defun_call(Tensors inputs)
4142
=> MakOp(inputs);
4243

43-
Tensor MakOp(Tensor inputs)
44+
Tensors MakOp(Tensors inputs)
4445
{
4546
foreach (var (index, constant) in enumerate(constants))
4647
{
47-
48+
var value = constant_op.constant(constant, name: node_def.Input[index]);
49+
var new_inputs = inputs.ToList();
50+
new_inputs.Insert(index, value);
51+
inputs = new Tensors(new_inputs.ToArray());
4852
}
4953

5054
var graph = inputs.graph;
51-
var (c_op, c_op_desc) = ops._create_c_op(graph, node_def, new[] { inputs }, new Operation[0]);
55+
var (c_op, _) = ops._create_c_op(graph, node_def, inputs.ToArray(), new Operation[0]);
5256
var op = graph._create_op_from_tf_operation(c_op);
5357
op._control_flow_post_processing();
5458

src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ public static IVariableV1 make_variable(VariableArgs args)
3737
var initializing_from_value = false;
3838
#pragma warning restore CS0219 // Variable is assigned but its value is never used
3939

40-
ops.init_scope();
41-
4240
Func<Tensor> init_val = () => args.Initializer.Apply(new InitializerArgs(args.Shape, dtype: args.DType));
4341

4442
var variable_dtype = args.DType.as_base_dtype();

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ public Layer(bool trainable = true,
4949
this._reuse = _reuse;
5050

5151
// Avoid an incorrect lint error
52-
trainableWeights = new List<IVariableV1>();
53-
nonTrainableWeights = new List<IVariableV1>();
52+
trainable_weights = new List<IVariableV1>();
53+
non_trainable_weights = new List<IVariableV1>();
5454
this.built = false;
5555
_keras_style = false;
5656
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[
170170
if (op_def == null)
171171
op_def = g.GetOpDef(node_def.Op);
172172

173-
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
174-
(_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
173+
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray());
175174
_is_stateful = op_def.IsStateful;
176175

177176
// Initialize self._outputs.
@@ -194,39 +193,6 @@ public void run(FeedItem[] feed_dict = null, Session session = null)
194193
ops._run_using_default_session(this, feed_dict, graph, session);
195194
}
196195

197-
private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs)
198-
{
199-
var grouped_inputs = new List<object>();
200-
int i = 0;
201-
int input_len = 0;
202-
bool is_sequence = false;
203-
foreach (var input_arg in op_def.InputArg)
204-
{
205-
if (!string.IsNullOrEmpty(input_arg.NumberAttr))
206-
{
207-
input_len = (int) attrs[input_arg.NumberAttr].I;
208-
is_sequence = true;
209-
} else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
210-
{
211-
input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
212-
is_sequence = true;
213-
} else
214-
{
215-
input_len = 1;
216-
is_sequence = false;
217-
}
218-
219-
if (is_sequence)
220-
grouped_inputs.Add(inputs.Skip(i).Take(input_len).ToArray());
221-
else
222-
grouped_inputs.Add(inputs[i]);
223-
224-
i += input_len;
225-
}
226-
227-
return grouped_inputs.ToArray();
228-
}
229-
230196
public virtual T get_attr<T>(string name)
231197
=> (T)get_attr(name);
232198

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,10 +490,13 @@ public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null)
490490
}
491491

492492
public static Tensor zeros_like(Tensor x, string name = null)
493-
{
494-
var _op = tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x });
495-
return _op.outputs[0];
496-
}
493+
=> tf.Context.RunInAutoMode(()
494+
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()
495+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
496+
"ZerosLike", name,
497+
null,
498+
x).FirstOrDefault(),
499+
x);
497500

498501
public static Tensor stop_gradient(Tensor x, string name = null)
499502
{

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -568,11 +568,13 @@ public static Tensor less_equal<Tx, Ty>(Tx x, Ty y, string name = null)
568568
}
569569

570570
public static Tensor log1p(Tensor x, string name = null)
571-
{
572-
var _op = tf.OpDefLib._apply_op_helper("Log1p", name, args: new { x });
573-
574-
return _op.outputs[0];
575-
}
571+
=> tf.Context.RunInAutoMode(()
572+
=> tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, ()
573+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
574+
"Log1p", name,
575+
null,
576+
x).FirstOrDefault(),
577+
x);
576578

577579
public static Tensor logical_and(Tensor x, Tensor y, string name = null)
578580
{
@@ -1056,12 +1058,15 @@ public static Tensor _any<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, str
10561058
return _op.outputs[0];
10571059
}
10581060

1059-
public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims=false, string name = null)
1060-
{
1061-
var _op = tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims });
1062-
1063-
return _op.outputs[0];
1064-
}
1061+
public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null)
1062+
=> tf.Context.RunInAutoMode(()
1063+
=> tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, ()
1064+
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
1065+
"Max", name,
1066+
null,
1067+
input, axis,
1068+
"keep_dims", keep_dims).FirstOrDefault(),
1069+
input as Tensor);
10651070

10661071
public static Tensor _min<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null)
10671072
{

0 commit comments

Comments
 (0)