Skip to content

Commit 2969ed0

Browse files
authored
Merge pull request nlintz#46 from j-min/master
Fixed tensorflow 0.9 API compatibility
2 parents aeec3fb + 9e1218b commit 2969ed0

2 files changed

Lines changed: 8 additions & 14 deletions

File tree

.travis.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ install:
1111
- pip install matplotlib
1212
# install TensorFlow from https://storage.googleapis.com/tensorflow/
1313
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
14-
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl;
14+
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl;
1515
elif [[ "$TRAVIS_PYTHON_VERSION" == "3.4" ]]; then
16-
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp34-cp34m-linux_x86_64.whl;
16+
pip install https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl;
1717
fi
1818
script:
1919
- sed -i -- 's/range(100)/range(1)/g' ??_*.py # change range to 1 for quick testing

07_lstm.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#Inspired by https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3%20-%20Neural%20Networks/recurrent_network.py
22
import tensorflow as tf
3-
from tensorflow.models.rnn import rnn, rnn_cell
43

54
import numpy as np
65
import input_data
@@ -32,7 +31,7 @@ def init_weights(shape):
3231
return tf.Variable(tf.random_normal(shape, stddev=0.01))
3332

3433

35-
def model(X, W, B, init_state, lstm_size):
34+
def model(X, W, B, lstm_size):
3635
# X, input shape: (batch_size, input_vec_size, time_step_size)
3736
XT = tf.transpose(X, [1, 0, 2]) # permute time_step_size and batch_size
3837
# XT shape: (input_vec_size, batch_szie, time_step_size)
@@ -42,10 +41,10 @@ def model(X, W, B, init_state, lstm_size):
4241
# Each array shape: (batch_size, input_vec_size)
4342

4443
# Make lstm with lstm_size (each input vector size)
45-
lstm = rnn_cell.BasicLSTMCell(lstm_size, forget_bias=1.0)
44+
lstm = tf.nn.rnn_cell.BasicLSTMCell(lstm_size, forget_bias=1.0, state_is_tuple=True)
4645

4746
# Get lstm cell output, time_step_size (28) arrays with lstm_size output: (batch_size, lstm_size)
48-
outputs, _states = rnn.rnn(lstm, X_split, initial_state=init_state)
47+
outputs, _states = tf.nn.rnn(lstm, X_split, dtype=tf.float32)
4948

5049
# Linear activation
5150
# Get the last output
@@ -56,17 +55,14 @@ def model(X, W, B, init_state, lstm_size):
5655
trX = trX.reshape(-1, 28, 28)
5756
teX = teX.reshape(-1, 28, 28)
5857

59-
# Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
60-
init_state = tf.placeholder("float", [None, 2*lstm_size])
61-
6258
X = tf.placeholder("float", [None, 28, 28])
6359
Y = tf.placeholder("float", [None, 10])
6460

6561
# get lstm_size and output 10 labels
6662
W = init_weights([lstm_size, 10])
6763
B = init_weights([10])
6864

69-
py_x, state_size = model(X, W, B, init_state, lstm_size)
65+
py_x, state_size = model(X, W, B, lstm_size)
7066

7167
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
7268
train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
@@ -79,14 +75,12 @@ def model(X, W, B, init_state, lstm_size):
7975

8076
for i in range(100):
8177
for start, end in zip(range(0, len(trX), batch_size), range(batch_size, len(trX), batch_size)):
82-
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
83-
init_state: np.zeros((batch_size, state_size))})
78+
sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})
8479

8580
test_indices = np.arange(len(teX)) # Get A Test Batch
8681
np.random.shuffle(test_indices)
8782
test_indices = test_indices[0:test_size]
8883

8984
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
9085
sess.run(predict_op, feed_dict={X: teX[test_indices],
91-
Y: teY[test_indices],
92-
init_state: np.zeros((test_size, state_size))})))
86+
Y: teY[test_indices]})))

0 commit comments

Comments
 (0)