|
| 1 | +# Lab 12 RNN |
| 2 | +import tensorflow as tf |
| 3 | +import numpy as np |
| 4 | +tf.set_random_seed(777) # reproducibility |
| 5 | + |
| 6 | +dic = {0: 'h', 1: 'i', 2: 'e', 3: 'l', 4: 'o'} |
| 7 | +# Teach hello: hihell -> ihello |
| 8 | +x_data = [[0, 1, 0, 2, 3, 3]] # hihell |
| 9 | +x_one_hot = [[[1, 0, 0, 0, 0], # h 0 |
| 10 | + [0, 1, 0, 0, 0], # i 1 |
| 11 | + [1, 0, 0, 0, 0], # h 0 |
| 12 | + [0, 0, 1, 0, 0], # e 2 |
| 13 | + [0, 0, 0, 1, 0], # l 3 |
| 14 | + [0, 0, 0, 1, 0]]] # l 3 |
| 15 | + |
| 16 | +y_data = [[1, 0, 2, 3, 3, 4]] # ihello |
| 17 | + |
| 18 | +input_dim = 5 # one-hot size |
| 19 | +hidden_size = 5 # output from the LSTM 4 to directly predict onehot |
| 20 | +batch_size = 1 # one sentence |
| 21 | +sequence_length = 6 # |ihello| == 6 |
| 22 | + |
| 23 | +X = tf.placeholder( |
| 24 | + tf.float32, [None, sequence_length, hidden_size]) # X one-hot |
| 25 | +Y = tf.placeholder(tf.int32, [None, sequence_length]) # Y label |
| 26 | + |
| 27 | +cell = tf.contrib.rnn.BasicLSTMCell(num_units=hidden_size, state_is_tuple=True) |
| 28 | +initial_state = cell.zero_state(batch_size, tf.float32) |
| 29 | +outputs, _states = tf.nn.dynamic_rnn( |
| 30 | + cell, X, initial_state=initial_state, dtype=tf.float32) |
| 31 | + |
| 32 | +weights = tf.ones([batch_size, sequence_length]) |
| 33 | +sequence_loss = tf.contrib.seq2seq.sequence_loss( |
| 34 | + logits=outputs, targets=Y, weights=weights) |
| 35 | +loss = tf.reduce_mean(sequence_loss) |
| 36 | +train = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss) |
| 37 | + |
| 38 | +prediction = tf.argmax(outputs, axis=2) |
| 39 | + |
| 40 | +with tf.Session() as sess: |
| 41 | + sess.run(tf.global_variables_initializer()) |
| 42 | + for i in range(2000): |
| 43 | + l, _ = sess.run([loss, train], feed_dict={X: x_one_hot, Y: y_data}) |
| 44 | + result = sess.run(prediction, feed_dict={X: x_one_hot}) |
| 45 | + print("loss:", l, "prediction: ", result, "true Y: ", y_data) |
| 46 | + |
| 47 | + # print char using dic |
| 48 | + result_str = [dic[c] for c in np.squeeze(result)] |
| 49 | + print("Prediction str: ", ''.join(result_str)) |
0 commit comments