Skip to content

Commit 0cc7b1a

Browse files
author
lichuang
committed
add lstm code
1 parent 39468b7 commit 0cc7b1a

File tree

5 files changed

+446
-0
lines changed

5 files changed

+446
-0
lines changed

lstm_code/iamtrask/lstm.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# coding:utf-8
2+
import copy, numpy as np
3+
np.random.seed(0)
4+
# Logistic sigmoid nonlinearity: squashes any real input into (0, 1).
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
8+
9+
# Given y = sigmoid(x), the derivative dy/dx equals y * (1 - y); this lets
# backprop reuse the forward activation instead of recomputing exp().
def sigmoid_output_to_derivative(output):
    y = output
    return y * (1 - y)
12+
13+
14+
# ---------------------------------------------------------------------------
# Toy RNN demo: learn 8-bit binary addition (a + b = c) with a vanilla
# recurrent network trained by backpropagation through time (BPTT).
# NOTE(review): converted Python-2-only `print` statements to print() calls
# and made the operand-cap division explicitly integer so the script runs
# under Python 3 as well.
# ---------------------------------------------------------------------------

# training dataset generation: map every int in [0, 2**binary_dim) to its
# binary_dim-bit binary row
int2binary = {}
binary_dim = 8

largest_number = pow(2, binary_dim)
binary = np.unpackbits(
    np.array([range(largest_number)], dtype=np.uint8).T, axis=1)
for i in range(largest_number):
    int2binary[i] = binary[i]

# hyperparameters
alpha = 0.1        # learning rate
input_dim = 2      # one bit from a and one bit from b per time step
hidden_dim = 16
output_dim = 1     # one predicted sum bit per time step

# initialize neural network weights uniformly in [-1, 1)
synapse_0 = 2 * np.random.random((input_dim, hidden_dim)) - 1   # input  -> hidden
synapse_1 = 2 * np.random.random((hidden_dim, output_dim)) - 1  # hidden -> output
synapse_h = 2 * np.random.random((hidden_dim, hidden_dim)) - 1  # hidden -> hidden (recurrent)

synapse_0_update = np.zeros_like(synapse_0)
synapse_1_update = np.zeros_like(synapse_1)
synapse_h_update = np.zeros_like(synapse_h)

# training logic
for j in range(10000):

    # generate a simple addition problem (a + b = c); operands are capped at
    # largest_number // 2 so the sum always fits in binary_dim bits
    a_int = np.random.randint(largest_number // 2)  # int version
    a = int2binary[a_int]                           # binary encoding

    b_int = np.random.randint(largest_number // 2)  # int version
    b = int2binary[b_int]                           # binary encoding

    # true answer
    c_int = a_int + b_int
    c = int2binary[c_int]

    # where we'll store our best guess (binary encoded)
    d = np.zeros_like(c)

    overallError = 0

    layer_2_deltas = list()
    layer_1_values = list()
    # hidden state before the first time step is all zeros
    layer_1_values.append(np.zeros(hidden_dim))

    # forward pass: walk the bit positions from least to most significant
    for position in range(binary_dim):

        # generate input and output for this time step
        X = np.array([[a[binary_dim - position - 1], b[binary_dim - position - 1]]])
        y = np.array([[c[binary_dim - position - 1]]]).T

        # hidden layer (input + prev_hidden)
        layer_1 = sigmoid(np.dot(X, synapse_0) + np.dot(layer_1_values[-1], synapse_h))

        # output layer (new binary representation)
        layer_2 = sigmoid(np.dot(layer_1, synapse_1))

        # did we miss?... if so, by how much?
        layer_2_error = y - layer_2
        layer_2_deltas.append((layer_2_error) * sigmoid_output_to_derivative(layer_2))
        overallError += np.abs(layer_2_error[0])

        # decode estimate so we can print it out
        d[binary_dim - position - 1] = np.round(layer_2[0][0])

        # store hidden layer so we can use it in the next timestep
        layer_1_values.append(copy.deepcopy(layer_1))

    future_layer_1_delta = np.zeros(hidden_dim)

    # backward pass (BPTT): walk the time steps in reverse
    for position in range(binary_dim):

        X = np.array([[a[position], b[position]]])
        layer_1 = layer_1_values[-position - 1]
        prev_layer_1 = layer_1_values[-position - 2]

        # error at output layer
        layer_2_delta = layer_2_deltas[-position - 1]
        # error at hidden layer: output error plus error flowing back
        # from the following time step through the recurrent weights
        layer_1_delta = (future_layer_1_delta.dot(synapse_h.T) +
                         layer_2_delta.dot(synapse_1.T)) * sigmoid_output_to_derivative(layer_1)

        # accumulate weight updates (applied once the whole sequence is done)
        synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)
        synapse_h_update += np.atleast_2d(prev_layer_1).T.dot(layer_1_delta)
        synapse_0_update += X.T.dot(layer_1_delta)

        future_layer_1_delta = layer_1_delta

    # apply the accumulated gradients, then reset the accumulators
    synapse_0 += synapse_0_update * alpha
    synapse_1 += synapse_1_update * alpha
    synapse_h += synapse_h_update * alpha

    synapse_0_update *= 0
    synapse_1_update *= 0
    synapse_h_update *= 0

    # print out progress
    if (j % 1000 == 0):
        print("Error:" + str(overallError))
        print("Pred:" + str(d))
        print("True:" + str(c))
        out = 0
        for index, x in enumerate(reversed(d)):
            out += x * pow(2, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out))
        print("------------")

lstm_code/nicodjimenez/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# lstm
2+
A basic lstm network can be written from scratch in a few hundred lines of python, yet most of us have a hard time figuring out how lstm's actually work. The original Neural Computation [paper](https://www.google.com/url?sa=t&rct=j&q=&esrc=s&source=web&cd=3&cad=rja&uact=8&ved=0CDAQFjACahUKEwj1iZLX5efGAhVMpIgKHbv3DiI&url=http%3A%2F%2Fdeeplearning.cs.cmu.edu%2Fpdfs%2FHochreiter97_lstm.pdf&ei=ZuirVfW-GMzIogS777uQAg&usg=AFQjCNGoFvqrva4rDCNIcqNe_SiPL_VPxg&sig2=ZYnsGpdfHjRbK8xdr1thBg&bvm=bv.98197061,d.cGU) is too technical for non experts. Most blogs online on the topic seem to be written by people
3+
who have never implemented LSTMs, for people who will not implement them either. Other blogs are written by experts (like this [blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)) and lack simplified, illustrative source code that actually does something. The [Apollo](https://github.com/Russell91/apollo) library built on top of Caffe is terrific and features a fast LSTM implementation. However, the downside of efficient implementations is that the source code is hard to follow.
4+
5+
This repo features a minimal lstm implementation for people that are curious about lstms to the point of wanting to know how lstm's might be implemented. The code here follows notational conventions set forth in [this](http://arxiv.org/abs/1506.00019)
6+
well written tutorial introduction. This article should be read before trying to understand this code (at least the part about lstm's). By running `python test.py` you will have a minimal example of an lstm network learning to predict an output sequence of numbers in [-1,1] by using a Euclidean loss on the first element of each node's hidden layer.
7+
8+
Play with the code, add functionality, and try it on different datasets. Pull requests are welcome.
9+
10+
Please read [my blog article](http://nicodjimenez.github.io/2014/08/08/lstm.html) if you want details on the backprop part of the code.
11+
12+
Also, check out a version of this code written in the D programming language by Mathias Baumann: https://github.com/Marenz/lstm

lstm_code/nicodjimenez/lstm.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import random
2+
3+
import numpy as np
4+
import math
5+
6+
def sigmoid(x):
    """Elementwise logistic sigmoid: 1 / (1 + e^-x)."""
    z = np.exp(-x)
    return 1. / (1 + z)
8+
9+
# creates a uniform random array w/ values in [a, b) and shape args
def rand_arr(a, b, *args):
    """Return a uniform random array with values in [a, b) and shape *args.

    BUG FIX(review): the original called np.random.seed(0) on every
    invocation, so every same-shape call returned *identical* values —
    e.g. all four LSTM gate weight matrices (wg, wi, wf, wo) came out
    equal, defeating random symmetry breaking.  Seed once at the call
    site (as test.py already does) when reproducibility is needed.
    """
    return np.random.rand(*args) * (b - a) + a
13+
14+
class LstmParam:
    """Weights, biases, and gradient accumulators for one LSTM cell.

    Gate naming follows the tutorial convention: g = cell input (tanh),
    i = input gate, f = forget gate, o = output gate.  Every gate sees the
    concatenation [x(t), h(t-1)], hence concat_len columns per weight matrix.
    """

    def __init__(self, mem_cell_ct, x_dim):
        self.mem_cell_ct = mem_cell_ct
        self.x_dim = x_dim
        concat_len = x_dim + mem_cell_ct
        # weight matrices, one per gate, in the fixed order g, i, f, o
        for gate in "gifo":
            setattr(self, "w" + gate, rand_arr(-0.1, 0.1, mem_cell_ct, concat_len))
        # bias terms, same order
        for gate in "gifo":
            setattr(self, "b" + gate, rand_arr(-0.1, 0.1, mem_cell_ct))
        # diffs: derivative of the loss w.r.t. each parameter, zeroed until
        # backprop accumulates into them
        for gate in "gifo":
            setattr(self, "w" + gate + "_diff", np.zeros((mem_cell_ct, concat_len)))
            setattr(self, "b" + gate + "_diff", np.zeros(mem_cell_ct))

    def apply_diff(self, lr=1):
        """Take one SGD step (param -= lr * diff), then zero every diff."""
        for name in ("wg", "wi", "wf", "wo", "bg", "bi", "bf", "bo"):
            param = getattr(self, name)
            param -= lr * getattr(self, name + "_diff")  # in-place update
            # reset the accumulator for the next sequence
            setattr(self, name + "_diff", np.zeros_like(param))
57+
58+
class LstmState:
    """Per-timestep activations and input-side gradients of one LSTM node.

    g/i/f/o hold the gate activations, s the cell state, h the hidden
    output.  The bottom_diff_* arrays receive, during backprop, the
    gradients passed down to the previous timestep (h, s) and to the
    raw input (x).
    """

    def __init__(self, mem_cell_ct, x_dim):
        for name in ("g", "i", "f", "o", "s", "h"):
            setattr(self, name, np.zeros(mem_cell_ct))
        self.bottom_diff_h = np.zeros_like(self.h)
        self.bottom_diff_s = np.zeros_like(self.s)
        self.bottom_diff_x = np.zeros(x_dim)
70+
class LstmNode:
    """One unrolled timestep of the LSTM: forward activations and backprop."""

    def __init__(self, lstm_param, lstm_state):
        # store reference to parameters and to activations
        self.state = lstm_state
        self.param = lstm_param
        # non-recurrent input to node
        self.x = None
        # non-recurrent input concatenated with recurrent input
        self.xc = None

    def bottom_data_is(self, x, s_prev=None, h_prev=None):
        """Forward pass for one timestep.

        x is the input vector; s_prev / h_prev are the previous timestep's
        cell state and hidden output (None for the first node in the chain).
        """
        # if this is the first lstm node in the network
        # BUG FIX(review): was `s_prev == None` / `h_prev == None`, which on
        # ndarrays performs an *elementwise* comparison; `if` on that result
        # raises "truth value of an array is ambiguous" on modern numpy for
        # every timestep after the first.  Identity check is the correct test.
        if s_prev is None: s_prev = np.zeros_like(self.state.s)
        if h_prev is None: h_prev = np.zeros_like(self.state.h)
        # save data for use in backprop
        self.s_prev = s_prev
        self.h_prev = h_prev

        # concatenate x(t) and h(t-1)
        xc = np.hstack((x, h_prev))
        self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
        self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
        self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
        self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
        self.state.s = self.state.g * self.state.i + s_prev * self.state.f
        # NOTE(review): h = s * o with no tanh on s — a simplified LSTM
        # variant; matches the accompanying tutorial, not the classic cell.
        self.state.h = self.state.s * self.state.o
        self.x = x
        self.xc = xc

    def top_diff_is(self, top_diff_h, top_diff_s):
        """Backward pass: accumulate parameter diffs and compute bottom diffs.

        top_diff_h / top_diff_s are dL/dh(t) and dL/ds(t) arriving from the
        loss layer and from the following timestep.
        """
        # notice that top_diff_s is carried along the constant error carousel
        ds = self.state.o * top_diff_h + top_diff_s
        do = self.state.s * top_diff_h
        di = self.state.g * ds
        dg = self.state.i * ds
        df = self.s_prev * ds

        # diffs w.r.t. vector inside sigma / tanh function
        di_input = (1. - self.state.i) * self.state.i * di
        df_input = (1. - self.state.f) * self.state.f * df
        do_input = (1. - self.state.o) * self.state.o * do
        dg_input = (1. - self.state.g ** 2) * dg

        # accumulate diffs w.r.t. parameters
        self.param.wi_diff += np.outer(di_input, self.xc)
        self.param.wf_diff += np.outer(df_input, self.xc)
        self.param.wo_diff += np.outer(do_input, self.xc)
        self.param.wg_diff += np.outer(dg_input, self.xc)
        self.param.bi_diff += di_input
        self.param.bf_diff += df_input
        self.param.bo_diff += do_input
        self.param.bg_diff += dg_input

        # compute bottom diff w.r.t. the concatenated input [x, h_prev]
        dxc = np.zeros_like(self.xc)
        dxc += np.dot(self.param.wi.T, di_input)
        dxc += np.dot(self.param.wf.T, df_input)
        dxc += np.dot(self.param.wo.T, do_input)
        dxc += np.dot(self.param.wg.T, dg_input)

        # save bottom diffs, splitting dxc back into x and h components
        self.state.bottom_diff_s = ds * self.state.f
        self.state.bottom_diff_x = dxc[:self.param.x_dim]
        self.state.bottom_diff_h = dxc[self.param.x_dim:]
135+
class LstmNetwork():
    """A chain of LstmNode instances, one per input in the current sequence."""

    def __init__(self, lstm_param):
        self.lstm_param = lstm_param
        self.lstm_node_list = []
        # input sequence
        self.x_list = []

    def y_list_is(self, y_list, loss_layer):
        """
        Updates diffs by setting target sequence
        with corresponding loss layer.
        Will *NOT* update parameters. To update parameters,
        call self.lstm_param.apply_diff()
        """
        assert len(y_list) == len(self.x_list)
        idx = len(self.x_list) - 1

        # Last node: its cell state cannot affect the loss through any
        # h(t+1), so diff_s is zero and diff_h comes from the label alone.
        node = self.lstm_node_list[idx]
        loss = loss_layer.loss(node.state.h, y_list[idx])
        diff_h = loss_layer.bottom_diff(node.state.h, y_list[idx])
        node.top_diff_is(diff_h, np.zeros(self.lstm_param.mem_cell_ct))
        idx -= 1

        # Earlier nodes: add the diff flowing back from the successor's
        # bottom_diff_h, and carry diff_s along the constant error carousel.
        while idx >= 0:
            node = self.lstm_node_list[idx]
            succ = self.lstm_node_list[idx + 1]
            loss += loss_layer.loss(node.state.h, y_list[idx])
            diff_h = loss_layer.bottom_diff(node.state.h, y_list[idx]) + succ.state.bottom_diff_h
            node.top_diff_is(diff_h, succ.state.bottom_diff_s)
            idx -= 1

        return loss

    def x_list_clear(self):
        """Start a new sequence: keep the nodes for reuse, drop the inputs."""
        self.x_list = []

    def x_list_add(self, x):
        """Append input x and run the forward pass for that timestep."""
        self.x_list.append(x)
        if len(self.x_list) > len(self.lstm_node_list):
            # need another lstm node: allocate fresh state memory for it
            fresh_state = LstmState(self.lstm_param.mem_cell_ct, self.lstm_param.x_dim)
            self.lstm_node_list.append(LstmNode(self.lstm_param, fresh_state))

        # index of the timestep we just added
        idx = len(self.x_list) - 1
        if idx == 0:
            # no recurrent inputs yet
            self.lstm_node_list[idx].bottom_data_is(x)
        else:
            prev = self.lstm_node_list[idx - 1].state
            self.lstm_node_list[idx].bottom_data_is(x, prev.s, prev.h)
190+

lstm_code/nicodjimenez/test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import sys
3+
4+
from lstm import LstmParam, LstmNetwork
5+
6+
class ToyLossLayer:
    """
    Computes square loss with first element of hidden layer array.
    """
    # NOTE(review): the classmethods' first parameter was named `self`;
    # renamed to the conventional `cls` (no behavior change).

    @classmethod
    def loss(cls, pred, label):
        """Squared error between pred[0] and the scalar label."""
        return (pred[0] - label) ** 2

    @classmethod
    def bottom_diff(cls, pred, label):
        """Gradient of loss w.r.t. pred: nonzero only at index 0."""
        diff = np.zeros_like(pred)
        diff[0] = 2 * (pred[0] - label)
        return diff
19+
20+
def example_0():
    """Train a small LSTM to emit a fixed 4-value target sequence.

    Each iteration feeds 4 fixed random 50-dim input vectors and trains
    (squared loss on h[0] of each timestep, see ToyLossLayer) toward
    y_list.  NOTE(review): converted Python-2 `print` statements to
    print() calls and dropped the unused `concat_len` local.
    """
    # learns to repeat simple sequence from random inputs
    np.random.seed(0)

    # parameters for input data dimension and lstm cell count
    mem_cell_ct = 100
    x_dim = 50
    lstm_param = LstmParam(mem_cell_ct, x_dim)
    lstm_net = LstmNetwork(lstm_param)
    y_list = [-0.5, 0.2, 0.1, -0.5]
    input_val_arr = [np.random.random(x_dim) for _ in y_list]

    for cur_iter in range(100):
        print("cur iter: ", cur_iter)
        print("input_val_arr=", input_val_arr)
        print("y_list=", y_list)
        # forward pass over the whole sequence
        for ind in range(len(y_list)):
            lstm_net.x_list_add(input_val_arr[ind])
            print("y_pred[%d] : %f" % (ind, lstm_net.lstm_node_list[ind].state.h[0]))

        # backward pass, then apply the accumulated diffs and reset inputs
        loss = lstm_net.y_list_is(y_list, ToyLossLayer)
        print("loss: ", loss)
        lstm_param.apply_diff(lr=0.1)
        lstm_net.x_list_clear()
45+
46+
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    example_0()
48+

0 commit comments

Comments
 (0)