@@ -12,20 +12,16 @@ public class SimpleRnnTest
1212 {
1313 public void Run ( )
1414 {
15- tf . keras = new KerasInterface ( ) ;
16- var inputs = np . random . random ( ( 32 , 10 , 8 ) ) . astype ( np . float32 ) ;
17- var simple_rnn = tf . keras . layers . SimpleRNN ( 4 ) ;
18- var output = simple_rnn . Apply ( inputs ) ; // The output has shape `[32, 4]`.
19- if ( output . shape == ( 32 , 4 ) )
20- {
15+ tf . UseKeras < KerasInterface > ( ) ;
16+ var inputs = np . random . random ( ( 6 , 10 , 8 ) ) . astype ( np . float32 ) ;
17+ //var simple_rnn = tf.keras.layers.SimpleRNN(4);
18+ //var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
2119
22- }
23- /*simple_rnn = tf.keras.layers.SimpleRNN(
24- 4, return_sequences = True, return_state = True)
20+ var simple_rnn = tf . keras . layers . SimpleRNN ( 4 , return_sequences : true , return_state : true ) ;
2521
26- # whole_sequence_output has shape `[32, 10, 4]`.
27- # final_state has shape `[32, 4]`.
28- whole_sequence_output, final_state = simple_rnn(inputs)*/
22+ // whole_sequence_output has shape `[32, 10, 4]`.
23+ // final_state has shape `[32, 4]`.
24+ var ( whole_sequence_output , final_state ) = simple_rnn . Apply ( inputs ) ;
2925 }
3026 }
3127}
0 commit comments