Skip to content

Commit d7c7d3d

Browse files
committed
fix tf.layers.conv2d
1 parent a5ae56a commit d7c7d3d

11 files changed

Lines changed: 112 additions & 44 deletions

File tree

src/TensorFlowNET.Core/APIs/tf.init.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@ public static partial class tf
1414

1515
public static variable_scope variable_scope(string name,
1616
string default_name = null,
17-
object values = null,
17+
Tensor[] values = null,
1818
bool auxiliary_name_scope = true) => new variable_scope(name,
1919
default_name,
2020
values,
2121
auxiliary_name_scope);
2222

2323
public static variable_scope variable_scope(VariableScope scope,
2424
string default_name = null,
25-
object values = null,
25+
Tensor[] values = null,
2626
bool? reuse = null,
2727
bool auxiliary_name_scope = true) => new variable_scope(scope,
2828
default_name,

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Text;
55
using Tensorflow.Keras.Engine;
66
using Tensorflow.Keras.Utils;
7+
using Tensorflow.Train;
78
using static Tensorflow.Python;
89

910
namespace Tensorflow.Keras.Layers
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers
1415
/// as convolution, batch norm, etc. These operations require managing weights,
1516
/// losses, updates, and inter-layer connectivity.
1617
/// </summary>
17-
public class Layer : CheckpointableBase
18+
public class Layer : AutoTrackable
1819
{
1920
/// <summary>
2021
/// Indicates whether `build` needs to be called upon layer call, to create
@@ -84,32 +85,35 @@ public Tensor __call__(Tensor[] inputs,
8485
// models using the functional API).
8586
bool build_graph = tf_utils.are_all_symbolic_tensors(input_list);
8687

87-
// Handle Keras mask propagation from previous layer to current layer.
88-
Python.with(ops.name_scope(_name_scope()), delegate
88+
if (build_graph)
8989
{
90-
/*if (!built)
91-
{
92-
_maybe_build(inputs);
93-
built = true;
94-
}*/
90+
// Only create Keras history if at least one tensor originates from a
91+
// `keras.Input`. Otherwise this Layer may be being used outside the Keras
92+
// framework.
93+
// base_layer_utils.create_keras_history(inputs)
94+
}
9595

96-
if (build_graph)
96+
// with base_layer_utils.call_context(self):
97+
98+
// Handle Keras mask propagation from previous layer to current layer.
99+
// with base_layer_utils.call_context(self):
100+
// Check input assumptions set after layer building, e.g. input shape.
101+
if (build_graph)
102+
{
103+
// Symbolic execution on symbolic tensors. We will attempt to build
104+
// the corresponding TF subgraph inside `backend.get_graph()`
105+
var graph = backend.get_graph().as_default();
106+
with(ops.name_scope(_name_scope()), delegate
97107
{
98-
// Symbolic execution on symbolic tensors. We will attempt to build
99-
// the corresponding TF subgraph inside `backend.get_graph()`
100-
var graph = backend.get_graph().as_default();
101-
with(ops.name_scope(_name_scope()), delegate
102-
{
103-
// Build layer if applicable (if the `build` method has been
104-
// overridden).
105-
_maybe_build(inputs[0]);
106-
});
107-
108-
outputs = call(inputs[0], training: training);
109-
_handle_activity_regularization(inputs[0], outputs);
110-
_set_mask_metadata(inputs[0], outputs, null);
111-
}
112-
});
108+
// Build layer if applicable (if the `build` method has been
109+
// overridden).
110+
_maybe_build(inputs[0]);
111+
});
112+
113+
outputs = call(inputs[0], training: training);
114+
_handle_activity_regularization(inputs[0], outputs);
115+
_set_mask_metadata(inputs[0], outputs, null);
116+
}
113117

114118
return outputs;
115119
}
@@ -147,6 +151,8 @@ protected void _maybe_build(Tensor input)
147151
// Check input assumptions set before layer building, e.g. input rank.
148152
if (built)
149153
return;
154+
if (_dtype == TF_DataType.DtInvalid)
155+
_dtype = input.dtype;
150156

151157
build(input.GetShape());
152158
built = true;
@@ -170,10 +176,21 @@ protected virtual RefVariable add_weight(string name,
170176
if (trainable == null)
171177
trainable = true;
172178

179+
// Initialize variable when no initializer provided
180+
if(initializer == null)
181+
{
182+
// If dtype is DT_FLOAT, provide a uniform unit scaling initializer
183+
if (dtype.is_floating())
184+
initializer = tf.glorot_uniform_initializer;
185+
else if (dtype.is_integer())
186+
initializer = tf.zeros_initializer;
187+
else
188+
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.name}");
189+
}
173190
var variable = _add_variable_with_custom_getter(name,
174191
shape,
175192
dtype: dtype,
176-
//getter: getter == null ? base_layer_utils.make_variable : getter,
193+
getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
177194
overwrite: true,
178195
initializer: initializer,
179196
trainable: trainable.Value);

src/TensorFlowNET.Core/Keras/backend.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ public class backend
1212
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
1313
/// </summary>
1414
public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>();
15-
15+
public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>();
1616
public static void track_variable(RefVariable v)
1717
{
18-
18+
var graph = v.graph;
19+
_GRAPH_VARIABLES[graph.graph_key] = v;
1920
}
2021

2122
public static Tensor placeholder(int[] shape = null,

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,14 @@ public Tensor __call__(Tensor inputs,
5151
auxiliary_name_scope: false);
5252
}
5353

54-
with(scope_context_manager, scope2 => _current_scope = scope2);
55-
// Actually call layer
56-
var outputs = base.__call__(new Tensor[] { inputs }, training: training);
54+
Tensor outputs = null;
55+
with(scope_context_manager, scope2 =>
56+
{
57+
_current_scope = scope2;
58+
// Actually call layer
59+
outputs = base.__call__(new Tensor[] { inputs }, training: training);
60+
});
61+
5762

5863
// Update global default collections.
5964
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });
@@ -80,6 +85,17 @@ protected virtual void _add_elements_to_collection(Operation[] elements, string[
8085
}
8186
}
8287

88+
/// <summary>
89+
/// Adds a new variable to the layer, or gets an existing one; returns it.
90+
/// </summary>
91+
/// <param name="name"></param>
92+
/// <param name="shape"></param>
93+
/// <param name="dtype"></param>
94+
/// <param name="initializer"></param>
95+
/// <param name="trainable"></param>
96+
/// <param name="synchronization"></param>
97+
/// <param name="aggregation"></param>
98+
/// <returns></returns>
8399
protected virtual RefVariable add_weight(string name,
84100
int[] shape,
85101
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -157,7 +173,10 @@ private void _set_scope(VariableScope scope = null)
157173
else
158174
{
159175
with(tf.variable_scope(scope, default_name: _base_name),
160-
captured_scope => _scope = captured_scope);
176+
captured_scope =>
177+
{
178+
_scope = captured_scope;
179+
});
161180
}
162181

163182
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Train
6+
{
7+
public abstract class AutoTrackable : Trackable
8+
{
9+
}
10+
}

src/TensorFlowNET.Core/Train/Checkpointable/CheckpointableBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace Tensorflow
77
{
8-
public abstract class CheckpointableBase : Trackable
8+
public abstract class CheckpointableBase : AutoTrackable
99
{
1010

1111
}

src/TensorFlowNET.Core/Train/Trackable.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ protected virtual RefVariable _add_variable_with_custom_getter(string name,
1818
bool overwrite = false,
1919
bool trainable = false)
2020
{
21+
var checkpoint_initializer = true;
2122
var new_variable = getter(name, shape, dtype, initializer, trainable);
23+
24+
// If we set an initializer and the variable processed it, tracking will not
25+
// assign again. It will add this variable to our dependencies, and if there
26+
// is a non-trivial restoration queued, it will handle that. This also
27+
// handles slot variables.
2228
if (!overwrite || new_variable is RefVariable)
2329
return _track_checkpointable(new_variable, name: name,
2430
overwrite: overwrite);

src/TensorFlowNET.Core/Variables/PureVariableScope.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public PureVariableScope(VariableScope scope,
3535
_old_name_scope = old_name_scope;
3636
_var_store = variable_scope._get_default_variable_store();
3737
_var_scope_store = variable_scope.get_variable_scope_store();
38-
_new_name = _scope._name;
38+
_new_name = _scope.name;
3939

4040
string name_scope = _scope._name_scope;
4141
variable_scope_object = new VariableScope(_reuse,
@@ -55,7 +55,7 @@ public void __enter__()
5555
}
5656
else
5757
{
58-
_new_name = string.IsNullOrEmpty(_old._name) ? _name : _old._name + "/" + _name;
58+
_new_name = string.IsNullOrEmpty(_old.name) ? _name : _old.name + "/" + _name;
5959
_reuse = _reuse || _old.resue;
6060
string name_scope = _old_name_scope == null ? _name : _old_name_scope;
6161

src/TensorFlowNET.Core/Variables/VariableScope.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ public class VariableScope
1515
public bool resue;
1616

1717
private TF_DataType _dtype;
18-
public string _name { get; set; }
18+
string _name;
19+
public string name => _name;
1920
public string _name_scope { get; set; }
2021
public string original_name_scope => _name_scope;
2122

src/TensorFlowNET.Core/Variables/variable_scope.py.cs

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ public class variable_scope : IPython
1919
private string _name;
2020
private VariableScope _scope;
2121
private string _default_name;
22-
private object _values;
22+
private Tensor[] _values;
2323
private ops.NameScope _current_name_scope;
2424
private bool _auxiliary_name_scope;
2525
private PureVariableScope _cached_pure_variable_scope;
2626
private bool? _reuse;
27+
bool _in_graph_mode;
28+
protected Graph _graph;
2729

2830
public variable_scope(string name,
29-
string default_name = "",
30-
object values = null,
31+
string default_name = "",
32+
Tensor[] values = null,
3133
bool? reuse = null,
3234
bool auxiliary_name_scope = true)
3335
{
@@ -45,7 +47,7 @@ public variable_scope(string name,
4547

4648
public variable_scope(VariableScope scope,
4749
string default_name = "",
48-
object values = null,
50+
Tensor[] values = null,
4951
bool? reuse = null,
5052
bool auxiliary_name_scope = true)
5153
{
@@ -58,6 +60,11 @@ public variable_scope(VariableScope scope,
5860
if (_default_name == null && _scope == null)
5961
throw new TypeError("If default_name is None then scope is required");
6062

63+
if (_values == null)
64+
_values = new Tensor[0];
65+
_in_graph_mode = true;
66+
if (_in_graph_mode)
67+
_graph = ops._get_graph_from_inputs(_values);
6168
_auxiliary_name_scope = auxiliary_name_scope;
6269
}
6370

@@ -87,7 +94,7 @@ private VariableScope _enter_scope_uncached()
8794

8895
if (_name != null || _scope != null)
8996
{
90-
var name_scope = _name == null ? _scope._name.Split('/').Last() : _name;
97+
var name_scope = _name == null ? _scope.name.Split('/').Last() : _name;
9198
if (name_scope != null || current_name_scope != null)
9299
current_name_scope = ops.name_scope(name_scope);
93100
current_name_scope.__enter__();
@@ -124,7 +131,7 @@ public static string _get_unique_variable_scope(string prefix)
124131
{
125132
var var_scope_store = get_variable_scope_store();
126133
var current_scope = get_variable_scope();
127-
string name = !string.IsNullOrEmpty(current_scope._name) ? current_scope._name + "/" + prefix : prefix;
134+
string name = !string.IsNullOrEmpty(current_scope.name) ? current_scope.name + "/" + prefix : prefix;
128135
if (var_scope_store.variable_scope_count(name) == 0)
129136
return prefix;
130137
throw new NotImplementedException("_get_unique_variable_scope");

0 commit comments

Comments
 (0)