Skip to content

Commit 4cce9f6

Browse files
committed
added dynamic RNN
1 parent 1f58605 commit 4cce9f6

File tree

2 files changed

+196
-0
lines changed

2 files changed

+196
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ It is suitable for beginners who want to find clear and concise examples about T
2323
- Convolutional Neural Network ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/convolutional_network.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/convolutional_network.py))
2424
- Recurrent Neural Network (LSTM) ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/recurrent_network.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/recurrent_network.py))
2525
- Bidirectional Recurrent Neural Network (LSTM) ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/bidirectional_rnn.py))
26+
- Dynamic Recurrent Neural Network (LSTM) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/dynamic_rnn.py))
2627
- AutoEncoder ([notebook](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb)) ([code](https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py))
2728

2829
#### 4 - Utilities
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
'''
2+
A Dynamic Recurrent Neural Network (LSTM) implementation example using
3+
TensorFlow library. This example is using a toy dataset to classify linear
4+
sequences. The generated sequences have variable length.
5+
6+
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf
7+
8+
Author: Aymeric Damien
9+
Project: https://github.com/aymericdamien/TensorFlow-Examples/
10+
'''
11+
12+
import tensorflow as tf
13+
import random
14+
15+
16+
# ====================
17+
# TOY DATA GENERATOR
18+
# ====================
19+
class ToySequenceData(object):
    """ Generate sequences of data with dynamic length.

    Each generated sample is, with 50% probability, one of:
    - Class 0 (label [1., 0.]): linear sequences (i.e. [0, 1, 2, 3,...])
    - Class 1 (label [0., 1.]): random sequences (i.e. [1, 3, 10, 7,...])

    Every value is divided by 'max_value' and stored as a one-element
    list, so a sample is a list of 'max_seq_len' single-feature steps.

    NOTICE:
    Each sequence is zero-padded up to 'max_seq_len' so that all samples
    share the same dimensions (an array with inconsistent dimensions
    cannot be fed). The true length of every sample is recorded in the
    'seqlen' attribute so the padding can be ignored downstream.

    Attributes:
        data: list of padded sequences.
        labels: list of one-hot labels, aligned with 'data'.
        seqlen: list of the unpadded sequence lengths, aligned with 'data'.
        batch_id: index of the first sample of the next batch.
    """

    def __init__(self, n_samples=1000, max_seq_len=20, min_seq_len=3,
                 max_value=1000):
        """ Build 'n_samples' random samples.

        Args:
            n_samples: number of sequences to generate.
            max_seq_len: largest sequence length; also the padded length.
            min_seq_len: smallest sequence length.
            max_value: largest raw integer value, used as the scaling
                divisor. It must be at least 'max_seq_len', otherwise
                picking the start of a linear sequence raises ValueError.
        """
        self.data = []
        self.labels = []
        self.seqlen = []
        for _ in range(n_samples):
            # Random sequence length. Named 'seq_len' rather than 'len' so
            # the builtin is not shadowed.
            seq_len = random.randint(min_seq_len, max_seq_len)
            # Monitor sequence length for TensorFlow dynamic calculation
            self.seqlen.append(seq_len)
            # Pick a linear or a random int sequence (50% prob)
            if random.random() < .5:
                rand_start = random.randint(0, max_value - seq_len)
                values = range(rand_start, rand_start + seq_len)
                label = [1., 0.]
            else:
                values = [random.randint(0, max_value)
                          for _ in range(seq_len)]
                label = [0., 1.]
            # Scale the values, then pad for dimension consistency
            sample = [[float(value) / max_value] for value in values]
            sample += [[0.] for _ in range(max_seq_len - seq_len)]
            self.data.append(sample)
            self.labels.append(label)
        self.batch_id = 0

    def next(self, batch_size):
        """ Return a batch of data. When dataset end is reached, start over.

        The last batch before wrapping around may hold fewer than
        'batch_size' samples.

        Args:
            batch_size: maximum number of samples to return.

        Returns:
            A (batch_data, batch_labels, batch_seqlen) tuple of lists.
        """
        if self.batch_id == len(self.data):
            self.batch_id = 0
        start = self.batch_id
        end = min(start + batch_size, len(self.data))
        self.batch_id = end
        return (self.data[start:end], self.labels[start:end],
                self.seqlen[start:end])
74+
75+
76+
# ==========
77+
# MODEL
78+
# ==========
79+
80+
# Training hyper-parameters
learning_rate = 0.01
training_iters = 1000000
batch_size = 128
display_step = 10

# Network hyper-parameters
seq_max_len = 20  # Sequence max length (sequences are padded to it)
n_hidden = 64  # hidden layer num of features
n_classes = 2  # linear sequence or not

# Toy datasets: 1000 training samples and 500 test samples
trainset = ToySequenceData(n_samples=1000, max_seq_len=seq_max_len)
testset = ToySequenceData(n_samples=500, max_seq_len=seq_max_len)

# tf Graph inputs: padded sequences with one feature per step, and
# one-hot labels
x = tf.placeholder(tf.float32, shape=[None, seq_max_len, 1])
y = tf.placeholder(tf.float32, shape=[None, n_classes])
# A placeholder for indicating each sequence length
seqlen = tf.placeholder(tf.int32, shape=[None])

# Output layer parameters, randomly initialised
weights = dict(out=tf.Variable(tf.random_normal([n_hidden, n_classes])))
biases = dict(out=tf.Variable(tf.random_normal([n_classes])))
107+
108+
109+
def dynamicRNN(x, seqlen, weights, biases):
    """ Build the variable-length LSTM classifier and return its logits.

    Args:
        x: input tensor of shape (batch_size, n_steps, n_input); it is
            reshaped below with n_steps = 'seq_max_len' and n_input = 1.
        seqlen: int tensor holding the actual length of each sequence.
        weights: dict whose 'out' entry is the output weight matrix.
        biases: dict whose 'out' entry is the output bias vector.

    Returns:
        The linear output computed from each sample's output at its last
        valid timestep (index seqlen - 1).
    """
    # `rnn` wants a list of 'n_steps' tensors of shape
    # (batch_size, n_input), so swap batch_size and n_steps first...
    time_major = tf.transpose(x, [1, 0, 2])
    # ...flatten to (n_steps*batch_size, n_input)...
    flattened = tf.reshape(time_major, [-1, 1])
    # ...and cut into one (batch_size, n_input) tensor per step.
    step_inputs = tf.split(0, seq_max_len, flattened)

    # A basic LSTM cell with 'n_hidden' units
    cell = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)

    # Passing 'sequence_length' requests the dynamic calculation; the
    # returned state is not needed here.
    step_outputs, _ = tf.nn.rnn(cell, step_inputs, dtype=tf.float32,
                                sequence_length=seqlen)

    # The relevant output of a sample is the one produced at its own last
    # step (e.g. the 10th output for a sequence of length 10). It is picked
    # below by flattening the outputs and gathering one row per sample.

    # Stack the per-step outputs and go back to batch-major layout:
    # (batch_size, n_steps, n_hidden)
    stacked = tf.pack(step_outputs)
    batch_major = tf.transpose(stacked, [1, 0, 2])

    # Row of sample i at step t in the flattened tensor: i*seq_max_len + t
    n_rows = tf.shape(batch_major)[0]
    row_offsets = tf.range(0, n_rows) * seq_max_len
    last_step = seqlen - 1
    flat_index = row_offsets + last_step
    flat_outputs = tf.reshape(batch_major, [-1, n_hidden])
    last_outputs = tf.gather(flat_outputs, flat_index)

    # Linear activation on the last dynamically computed output
    return tf.matmul(last_outputs, weights['out']) + biases['out']
151+
152+
# Logits of the dynamic RNN for the graph inputs
pred = dynamicRNN(x, seqlen, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.GradientDescentOptimizer(
    learning_rate=learning_rate).minimize(cost)

# Evaluate model: share of samples whose predicted class matches the label
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.initialize_all_variables()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y, batch_seqlen = trainset.next(batch_size)
        # The same feed is used for the update and for the metrics
        feed = {x: batch_x, y: batch_y, seqlen: batch_seqlen}
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict=feed)
        if step % display_step == 0:
            # Fetch batch accuracy and batch loss in a single run so the
            # forward pass is computed once instead of twice
            acc, loss = sess.run([accuracy, cost], feed_dict=feed)
            # Single-argument print(...) emits the same text on Python 2
            # and Python 3
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy on the whole toy test set (all 'testset' samples)
    test_data = testset.data
    test_label = testset.labels
    test_seqlen = testset.seqlen
    test_accuracy = sess.run(accuracy,
                             feed_dict={x: test_data, y: test_label,
                                        seqlen: test_seqlen})
    print("Testing Accuracy: " + str(test_accuracy))

0 commit comments

Comments
 (0)