Skip to content

Commit 373d981

Browse files
zxiaomzxmaymericdamien
authored andcommitted
Simpilify RNN examples data transformation (aymericdamien#136)
* gittest * Simpilify RNN examples data transform.
1 parent 00e2927 commit 373d981

File tree

5 files changed

+11
-31
lines changed

5 files changed

+11
-31
lines changed

examples/3_NeuralNetworks/bidirectional_rnn.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,8 @@ def BiRNN(x, weights, biases):
5555
# Current data input shape: (batch_size, n_steps, n_input)
5656
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
5757

58-
# Permuting batch_size and n_steps
59-
x = tf.transpose(x, [1, 0, 2])
60-
# Reshape to (n_steps*batch_size, n_input)
61-
x = tf.reshape(x, [-1, n_input])
62-
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
63-
x = tf.split(x, n_steps, 0)
58+
# Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
59+
x = tf.unstack(x, n_steps, 1)
6460

6561
# Define lstm cells with tensorflow
6662
# Forward direction cell

examples/3_NeuralNetworks/dynamic_rnn.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,9 @@ def dynamicRNN(x, seqlen, weights, biases):
113113
# Prepare data shape to match `rnn` function requirements
114114
# Current data input shape: (batch_size, n_steps, n_input)
115115
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
116-
117-
# Permuting batch_size and n_steps
118-
x = tf.transpose(x, [1, 0, 2])
119-
# Reshaping to (n_steps*batch_size, n_input)
120-
x = tf.reshape(x, [-1, 1])
121-
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
122-
x = tf.split(axis=0, num_or_size_splits=seq_max_len, value=x)
116+
117+
# Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
118+
x = tf.unstack(x, seq_max_len, 1)
123119

124120
# Define a lstm cell with tensorflow
125121
lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)

examples/3_NeuralNetworks/recurrent_network.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,8 @@ def RNN(x, weights, biases):
5353
# Current data input shape: (batch_size, n_steps, n_input)
5454
# Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
5555

56-
# Permuting batch_size and n_steps
57-
x = tf.transpose(x, [1, 0, 2])
58-
# Reshaping to (n_steps*batch_size, n_input)
59-
x = tf.reshape(x, [-1, n_input])
60-
# Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
61-
x = tf.split(x, n_steps, 0)
56+
# Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
57+
x = tf.unstack(x, n_steps, 1)
6258

6359
# Define a lstm cell with tensorflow
6460
lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,8 @@
9494
" # Current data input shape: (batch_size, n_steps, n_input)\n",
9595
" # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
9696
" \n",
97-
" # Permuting batch_size and n_steps\n",
98-
" x = tf.transpose(x, [1, 0, 2])\n",
99-
" # Reshape to (n_steps*batch_size, n_input)\n",
100-
" x = tf.reshape(x, [-1, n_input])\n",
101-
" # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
102-
" x = tf.split(x, n_steps, 0)\n",
97+
" # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
98+
" x = tf.unstack(x, n_steps, 1)\n",
10399
"\n",
104100
" # Define lstm cells with tensorflow\n",
105101
" # Forward direction cell\n",

notebooks/3_NeuralNetworks/recurrent_network.ipynb

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,8 @@
9393
" # Current data input shape: (batch_size, n_steps, n_input)\n",
9494
" # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
9595
" \n",
96-
" # Permuting batch_size and n_steps\n",
97-
" x = tf.transpose(x, [1, 0, 2])\n",
98-
" # Reshaping to (n_steps*batch_size, n_input)\n",
99-
" x = tf.reshape(x, [-1, n_input])\n",
100-
" # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
101-
" x = tf.split(x, n_steps, 0)\n",
96+
" # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
97+
" x = tf.unstack(x, n_steps, 1)\n",
10298
"\n",
10399
" # Define a lstm cell with tensorflow\n",
104100
" lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",

0 commit comments

Comments
 (0)