@@ -121,7 +121,7 @@ public Layer(LayerArgs args)
121121 /// <param name="input"></param>
122122 /// <param name="is_training"></param>
123123 /// <returns></returns>
124- public Tensor Apply ( Tensor inputs , bool is_training = false , Tensor state = null )
124+ public Tensor Apply ( Tensor inputs , bool is_training = false )
125125 {
126126 Tensor outputs = null ;
127127
@@ -148,7 +148,7 @@ public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null
148148 if ( ! built )
149149 MaybeBuild ( inputs ) ;
150150
151- outputs = call ( inputs , is_training : is_training , state : state ) ;
151+ outputs = call ( inputs , is_training : is_training ) ;
152152
153153 outputs = _set_connectivity_metadata_ ( inputs , outputs ) ;
154154 _handle_activity_regularization ( inputs , outputs ) ;
@@ -161,6 +161,35 @@ public Tensor Apply(Tensor inputs, bool is_training = false, Tensor state = null
161161 return outputs ;
162162 }
163163
164+ public Tensor [ ] Apply ( Tensor [ ] inputs , Tensor state , bool is_training = false )
165+ {
166+ Tensor [ ] outputs = null ;
167+
168+ callContext = callContext ?? new ThreadLocal < CallContext > ( )
169+ {
170+ Value = new CallContext ( )
171+ } ;
172+
173+ var eager = tf . executing_eagerly ( ) ;
174+ using var ctxManager = CallContext . enter ( ) ;
175+
176+ string nameScope = "" ;
177+ if ( eager )
178+ nameScope = name ;
179+ else
180+ nameScope = _name_scope ( ) ;
181+
182+ tf_with ( ops . name_scope ( nameScope ) , scope =>
183+ {
184+ if ( ! built )
185+ MaybeBuild ( inputs [ 0 ] ) ;
186+
187+ outputs = call ( inputs , is_training : is_training , state : state ) ;
188+ } ) ;
189+
190+ return outputs ;
191+ }
192+
164193 private Tensor _set_connectivity_metadata_ ( Tensor inputs , Tensor outputs )
165194 {
166195 /*var returnOutputs = new List<Tensor>();
@@ -200,7 +229,12 @@ private Tensor compute_mask(Tensor inputs, Tensor mask = null)
200229 return null ;
201230 }
202231
203- protected virtual Tensor call ( Tensor inputs , bool is_training = false , Tensor state = null )
232+ protected virtual Tensor call ( Tensor inputs , bool is_training = false )
233+ {
234+ throw new NotImplementedException ( "" ) ;
235+ }
236+
237+ protected virtual Tensor [ ] call ( Tensor [ ] inputs , Tensor state , bool is_training = false )
204238 {
205239 throw new NotImplementedException ( "" ) ;
206240 }
0 commit comments