44using System . Text ;
55using Tensorflow . Keras . Engine ;
66using Tensorflow . Keras . Utils ;
7+ using Tensorflow . Train ;
78using static Tensorflow . Python ;
89
910namespace Tensorflow . Keras . Layers
@@ -14,7 +15,7 @@ namespace Tensorflow.Keras.Layers
1415 /// as convolution, batch norm, etc. These operations require managing weights,
1516 /// losses, updates, and inter-layer connectivity.
1617 /// </summary>
17- public class Layer : CheckpointableBase
18+ public class Layer : AutoTrackable
1819 {
1920 /// <summary>
2021 /// Indicates whether `build` needs to be called upon layer call, to create
@@ -84,32 +85,35 @@ public Tensor __call__(Tensor[] inputs,
8485 // models using the functional API).
8586 bool build_graph = tf_utils . are_all_symbolic_tensors ( input_list ) ;
8687
87- // Handle Keras mask propagation from previous layer to current layer.
88- Python . with ( ops . name_scope ( _name_scope ( ) ) , delegate
88+ if ( build_graph )
8989 {
90- /* if (!built)
91- {
92- _maybe_build(inputs);
93- built = true;
94- }*/
90+ // Only create Keras history if at least one tensor originates from a
91+ // `keras.Input`. Otherwise this Layer may be being used outside the Keras
92+ // framework.
93+ // base_layer_utils.create_keras_history(inputs)
94+ }
9595
96- if ( build_graph )
96+ // with base_layer_utils.call_context(self):
97+
98+ // Handle Keras mask propagation from previous layer to current layer.
99+ // with base_layer_utils.call_context(self):
100+ // Check input assumptions set after layer building, e.g. input shape.
101+ if ( build_graph )
102+ {
103+ // Symbolic execution on symbolic tensors. We will attempt to build
104+ // the corresponding TF subgraph inside `backend.get_graph()`
105+ var graph = backend . get_graph ( ) . as_default ( ) ;
106+ with ( ops . name_scope ( _name_scope ( ) ) , delegate
97107 {
98- // Symbolic execution on symbolic tensors. We will attempt to build
99- // the corresponding TF subgraph inside `backend.get_graph()`
100- var graph = backend . get_graph ( ) . as_default ( ) ;
101- with ( ops . name_scope ( _name_scope ( ) ) , delegate
102- {
103- // Build layer if applicable (if the `build` method has been
104- // overridden).
105- _maybe_build ( inputs [ 0 ] ) ;
106- } ) ;
107-
108- outputs = call ( inputs [ 0 ] , training : training ) ;
109- _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
110- _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
111- }
112- } ) ;
108+ // Build layer if applicable (if the `build` method has been
109+ // overridden).
110+ _maybe_build ( inputs [ 0 ] ) ;
111+ } ) ;
112+
113+ outputs = call ( inputs [ 0 ] , training : training ) ;
114+ _handle_activity_regularization ( inputs [ 0 ] , outputs ) ;
115+ _set_mask_metadata ( inputs [ 0 ] , outputs , null ) ;
116+ }
113117
114118 return outputs ;
115119 }
@@ -147,6 +151,8 @@ protected void _maybe_build(Tensor input)
147151 // Check input assumptions set before layer building, e.g. input rank.
148152 if ( built )
149153 return ;
154+ if ( _dtype == TF_DataType . DtInvalid )
155+ _dtype = input . dtype ;
150156
151157 build ( input . GetShape ( ) ) ;
152158 built = true ;
@@ -170,10 +176,21 @@ protected virtual RefVariable add_weight(string name,
170176 if ( trainable == null )
171177 trainable = true ;
172178
179+ // Initialize variable when no initializer provided
180+ if ( initializer == null )
181+ {
182+ // If dtype is DT_FLOAT, provide a uniform unit scaling initializer
183+ if ( dtype . is_floating ( ) )
184+ initializer = tf . glorot_uniform_initializer ;
185+ else if ( dtype . is_integer ( ) )
186+ initializer = tf . zeros_initializer ;
187+ else
188+ throw new ValueError ( $ "An initializer for variable { name } of type { dtype . as_base_dtype ( ) } is required for layer { this . name } ") ;
189+ }
173190 var variable = _add_variable_with_custom_getter ( name ,
174191 shape ,
175192 dtype : dtype ,
176- // getter: getter == null ? base_layer_utils.make_variable : getter,
193+ getter : getter , // getter == null ? base_layer_utils.make_variable : getter,
177194 overwrite : true ,
178195 initializer : initializer ,
179196 trainable : trainable . Value ) ;
0 commit comments