Skip to content

Commit d2d0e33

Browse files
committed
Added RNN 01 thanks to Jenny
1 parent d8b6f9d commit d2d0e33

2 files changed

Lines changed: 49 additions & 41 deletions

File tree

lab-12-1-hello-rnn.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Lab 12 RNN
2+
import tensorflow as tf
3+
import numpy as np
4+
tf.set_random_seed(777) # reproducibility
5+
6+
dic = {0: 'h', 1: 'i', 2: 'e', 3: 'l', 4: 'o'}
7+
# Teach hello: hihell -> ihello
8+
x_data = [[0, 1, 0, 2, 3, 3]] # hihell
9+
x_one_hot = [[[1, 0, 0, 0, 0], # h 0
10+
[0, 1, 0, 0, 0], # i 1
11+
[1, 0, 0, 0, 0], # h 0
12+
[0, 0, 1, 0, 0], # e 2
13+
[0, 0, 0, 1, 0], # l 3
14+
[0, 0, 0, 1, 0]]] # l 3
15+
16+
y_data = [[1, 0, 2, 3, 3, 4]] # ihello
17+
18+
input_dim = 5 # one-hot size
19+
hidden_size = 5 # output from the LSTM 4 to directly predict onehot
20+
batch_size = 1 # one sentence
21+
sequence_length = 6 # |ihello| == 6
22+
23+
X = tf.placeholder(
24+
tf.float32, [None, sequence_length, hidden_size]) # X one-hot
25+
Y = tf.placeholder(tf.int32, [None, sequence_length]) # Y label
26+
27+
cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True)
28+
initial_state = cell.zero_state(batch_size, tf.float32)
29+
outputs, _states = tf.nn.dynamic_rnn(
30+
cell, X, initial_state=initial_state, dtype=tf.float32)
31+
32+
weights = tf.ones([batch_size, sequence_length])
33+
sequence_loss = tf.contrib.seq2seq.sequence_loss(
34+
logits=outputs, targets=Y, weights=weights)
35+
loss = tf.reduce_mean(sequence_loss)
36+
train = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)
37+
38+
prediction = tf.argmax(outputs, axis=2)
39+
40+
with tf.Session() as sess:
41+
sess.run(tf.global_variables_initializer())
42+
for i in range(2000):
43+
l, _ = sess.run([loss, train], feed_dict={X: x_one_hot, Y: y_data})
44+
result = sess.run(prediction, feed_dict={X: x_one_hot})
45+
print("loss:", l, "prediction: ", result, "true Y: ", y_data)
46+
47+
# print char using dic
48+
result_str = [dic[c] for c in np.squeeze(result)]
49+
print("Prediction str: ", ''.join(result_str))

lab-12-1-rnn.py

Lines changed: 0 additions & 41 deletions
This file was deleted.

0 commit comments

Comments
 (0)