@@ -178,8 +178,13 @@ def __call__(self, inputs, state, scope=None):
178178 custom_getter = self ._rnn_get_variable ) as scope :
179179 return super (RNNCell , self ).__call__ (inputs , state , scope = scope )
180180 else :
181- with vs .variable_scope (vs .get_variable_scope (),
182- custom_getter = self ._rnn_get_variable ):
181+ scope_attrname = "rnncell_scope"
182+ scope = getattr (self , scope_attrname , None )
183+ if scope is None :
184+ scope = vs .variable_scope (vs .get_variable_scope (),
185+ custom_getter = self ._rnn_get_variable )
186+ setattr (self , scope_attrname , scope )
187+ with scope :
183188 return super (RNNCell , self ).__call__ (inputs , state )
184189
185190 def _rnn_get_variable (self , getter , * args , ** kwargs ):
@@ -230,9 +235,20 @@ def zero_state(self, batch_size, dtype):
230235 a nested list or tuple (of the same structure) of `2-D` tensors with
231236 the shapes `[batch_size x s]` for each s in `state_size`.
232237 """
238+ # Try to use the last cached zero_state. This is done to avoid recreating
239+ # zeros, especially when eager execution is enabled.
240+ state_size = self .state_size
241+ if hasattr (self , "_last_zero_state" ):
242+ (last_state_size , last_batch_size , last_dtype ,
243+ last_output ) = getattr (self , "_last_zero_state" )
244+ if (last_batch_size == batch_size and
245+ last_dtype == dtype and
246+ last_state_size == state_size ):
247+ return last_output
233248 with ops .name_scope (type (self ).__name__ + "ZeroState" , values = [batch_size ]):
234- state_size = self .state_size
235- return _zero_state_tensors (state_size , batch_size , dtype )
249+ output = _zero_state_tensors (state_size , batch_size , dtype )
250+ self ._last_zero_state = (state_size , batch_size , dtype , output )
251+ return output
236252
237253
238254class BasicRNNCell (RNNCell ):
@@ -428,21 +444,27 @@ def call(self, inputs, state):
428444 `state_is_tuple`).
429445 """
430446 sigmoid = math_ops .sigmoid
447+ one = constant_op .constant (1 , dtype = dtypes .int32 )
431448 # Parameters of gates are concatenated into one multiply for efficiency.
432449 if self ._state_is_tuple :
433450 c , h = state
434451 else :
435- c , h = array_ops .split (value = state , num_or_size_splits = 2 , axis = 1 )
452+ c , h = array_ops .split (value = state , num_or_size_splits = 2 , axis = one )
436453
437454 if self ._linear is None :
438455 self ._linear = _Linear ([inputs , h ], 4 * self ._num_units , True )
439456 # i = input_gate, j = new_input, f = forget_gate, o = output_gate
440457 i , j , f , o = array_ops .split (
441- value = self ._linear ([inputs , h ]), num_or_size_splits = 4 , axis = 1 )
458+ value = self ._linear ([inputs , h ]), num_or_size_splits = 4 , axis = one )
442459
443- new_c = (
444- c * sigmoid (f + self ._forget_bias ) + sigmoid (i ) * self ._activation (j ))
445- new_h = self ._activation (new_c ) * sigmoid (o )
460+ forget_bias_tensor = constant_op .constant (self ._forget_bias , dtype = f .dtype )
461+ # Note that using `add` and `multiply` instead of `+` and `*` gives a
462+ # performance improvement, so they are used here at the cost of readability.
463+ add = math_ops .add
464+ multiply = math_ops .multiply
465+ new_c = add (multiply (c , sigmoid (add (f , forget_bias_tensor ))),
466+ multiply (sigmoid (i ), self ._activation (j )))
467+ new_h = multiply (self ._activation (new_c ), sigmoid (o ))
446468
447469 if self ._state_is_tuple :
448470 new_state = LSTMStateTuple (new_c , new_h )
@@ -1186,7 +1208,9 @@ def __call__(self, args):
11861208 if len (args ) == 1 :
11871209 res = math_ops .matmul (args [0 ], self ._weights )
11881210 else :
1189- res = math_ops .matmul (array_ops .concat (args , 1 ), self ._weights )
1211+ # Explicitly creating a one for a minor performance improvement.
1212+ one = constant_op .constant (1 , dtype = dtypes .int32 )
1213+ res = math_ops .matmul (array_ops .concat (args , one ), self ._weights )
11901214 if self ._build_bias :
11911215 res = nn_ops .bias_add (res , self ._biases )
11921216 return res
0 commit comments