|
| 1 | +import random |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import math |
| 5 | + |
def sigmoid(x):
    """Elementwise logistic sigmoid, 1 / (1 + exp(-x)).

    Written via tanh — mathematically identical, but np.exp(-x) overflows
    (with a RuntimeWarning) for large negative x, while tanh saturates
    cleanly.  Accepts scalars or numpy arrays.
    """
    return 0.5 * (1. + np.tanh(0.5 * x))
| 8 | + |
# creates uniform random array w/ values in [a,b) and shape args
def rand_arr(a, b, *args):
    """Return an array of shape *args with entries drawn uniformly from [a, b).

    BUGFIX: the original called np.random.seed(0) on every invocation, so
    every array produced here was identical (values depended only on shape) —
    all four LSTM weight matrices came out equal, defeating random
    initialization.  Seed the global RNG once at program start instead if
    reproducibility is needed.
    """
    return np.random.rand(*args) * (b - a) + a
| 13 | + |
class LstmParam:
    """Weights, biases, and their accumulated gradients for one LSTM cell.

    A single instance is shared by every unrolled timestep; gradients are
    accumulated into the *_diff members during backprop and consumed by
    apply_diff().
    """

    def __init__(self, mem_cell_ct, x_dim):
        self.mem_cell_ct = mem_cell_ct
        self.x_dim = x_dim
        # gates see [x(t); h(t-1)], hence the concatenated width
        concat_len = x_dim + mem_cell_ct
        # weight matrices: candidate (g), input (i), forget (f), output (o)
        self.wg = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wi = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wf = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        self.wo = rand_arr(-0.1, 0.1, mem_cell_ct, concat_len)
        # bias vectors, one per gate
        self.bg = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bi = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bf = rand_arr(-0.1, 0.1, mem_cell_ct)
        self.bo = rand_arr(-0.1, 0.1, mem_cell_ct)
        # gradient accumulators (dLoss/dParam), zero until backprop runs
        for gate in ("g", "i", "f", "o"):
            setattr(self, "w%s_diff" % gate, np.zeros((mem_cell_ct, concat_len)))
            setattr(self, "b%s_diff" % gate, np.zeros(mem_cell_ct))

    def apply_diff(self, lr=1):
        """Take one gradient-descent step of size `lr`, then zero the gradients."""
        for gate in ("g", "i", "f", "o"):
            for prefix in ("w", "b"):
                name = prefix + gate
                param = getattr(self, name)
                param -= lr * getattr(self, name + "_diff")
                # reset the accumulator for the next sequence
                setattr(self, name + "_diff", np.zeros_like(param))
| 57 | + |
class LstmState:
    """Per-timestep activations plus the diffs a node hands back downstream."""

    def __init__(self, mem_cell_ct, x_dim):
        # gate activations: candidate (g), input (i), forget (f), output (o)
        self.g = np.zeros(mem_cell_ct)
        self.i = np.zeros(mem_cell_ct)
        self.f = np.zeros(mem_cell_ct)
        self.o = np.zeros(mem_cell_ct)
        # cell state and hidden output for this timestep
        self.s = np.zeros(mem_cell_ct)
        self.h = np.zeros(mem_cell_ct)
        # gradients flowing to the previous timestep (h, s) and to the input x
        self.bottom_diff_h = np.zeros(mem_cell_ct)
        self.bottom_diff_s = np.zeros(mem_cell_ct)
        self.bottom_diff_x = np.zeros(x_dim)
| 69 | + |
class LstmNode:
    """One unrolled LSTM timestep.

    Holds a reference to the shared LstmParam and writes its activations
    into the LstmState it was constructed with.  Note this variant applies
    no tanh to the cell state when forming h (h = s * o); the backward pass
    below is consistent with that choice.
    """

    def __init__(self, lstm_param, lstm_state):
        # store reference to parameters and to activations
        self.state = lstm_state
        self.param = lstm_param
        # non-recurrent input to node
        self.x = None
        # non-recurrent input concatenated with recurrent input
        self.xc = None

    def bottom_data_is(self, x, s_prev=None, h_prev=None):
        """Forward pass: compute gates, cell state s, and output h for input x.

        s_prev/h_prev are the previous timestep's cell state / output;
        pass None for the first node in the sequence.
        """
        # if this is the first lstm node in the network, start from zeros.
        # BUGFIX: was `s_prev == None`, which on a numpy array is an
        # elementwise comparison whose result raises ValueError when used
        # in boolean context; identity comparison is the correct test.
        if s_prev is None:
            s_prev = np.zeros_like(self.state.s)
        if h_prev is None:
            h_prev = np.zeros_like(self.state.h)
        # save data for use in backprop
        self.s_prev = s_prev
        self.h_prev = h_prev

        # concatenate x(t) and h(t-1)
        xc = np.hstack((x, h_prev))
        self.state.g = np.tanh(np.dot(self.param.wg, xc) + self.param.bg)
        self.state.i = sigmoid(np.dot(self.param.wi, xc) + self.param.bi)
        self.state.f = sigmoid(np.dot(self.param.wf, xc) + self.param.bf)
        self.state.o = sigmoid(np.dot(self.param.wo, xc) + self.param.bo)
        self.state.s = self.state.g * self.state.i + s_prev * self.state.f
        self.state.h = self.state.s * self.state.o
        self.x = x
        self.xc = xc

    def top_diff_is(self, top_diff_h, top_diff_s):
        """Backward pass: accumulate parameter gradients and bottom diffs.

        top_diff_h is dLoss/dh(t) from the layer above plus the next
        timestep; top_diff_s is carried along the constant error carousel.
        Must be called after bottom_data_is (uses saved s_prev / xc).
        """
        # d/ds and per-gate diffs; since h = s * o, dh/ds = o and dh/do = s
        ds = self.state.o * top_diff_h + top_diff_s
        do = self.state.s * top_diff_h
        di = self.state.g * ds
        dg = self.state.i * ds
        df = self.s_prev * ds

        # diffs w.r.t. vector inside sigma / tanh function
        # (sigmoid' = y(1-y), tanh' = 1 - y^2, with y the activation)
        di_input = (1. - self.state.i) * self.state.i * di
        df_input = (1. - self.state.f) * self.state.f * df
        do_input = (1. - self.state.o) * self.state.o * do
        dg_input = (1. - self.state.g ** 2) * dg

        # accumulate diffs w.r.t. parameters (summed over timesteps)
        self.param.wi_diff += np.outer(di_input, self.xc)
        self.param.wf_diff += np.outer(df_input, self.xc)
        self.param.wo_diff += np.outer(do_input, self.xc)
        self.param.wg_diff += np.outer(dg_input, self.xc)
        self.param.bi_diff += di_input
        self.param.bf_diff += df_input
        self.param.bo_diff += do_input
        self.param.bg_diff += dg_input

        # compute bottom diff w.r.t. the concatenated input [x; h_prev]
        dxc = np.zeros_like(self.xc)
        dxc += np.dot(self.param.wi.T, di_input)
        dxc += np.dot(self.param.wf.T, df_input)
        dxc += np.dot(self.param.wo.T, do_input)
        dxc += np.dot(self.param.wg.T, dg_input)

        # save bottom diffs, splitting dxc back into its x and h parts
        self.state.bottom_diff_s = ds * self.state.f
        self.state.bottom_diff_x = dxc[:self.param.x_dim]
        self.state.bottom_diff_h = dxc[self.param.x_dim:]
| 134 | + |
class LstmNetwork():
    """An LSTM unrolled over an input sequence, one LstmNode per timestep."""

    def __init__(self, lstm_param):
        self.lstm_param = lstm_param
        self.lstm_node_list = []
        # input sequence
        self.x_list = []

    def y_list_is(self, y_list, loss_layer):
        """
        Updates diffs by setting target sequence
        with corresponding loss layer.
        Will *NOT* update parameters. To update parameters,
        call self.lstm_param.apply_diff()
        """
        assert len(y_list) == len(self.x_list)
        seq_len = len(self.x_list)
        total_loss = 0.
        # walk the sequence backwards; the last node only gets diffs from
        # its label, earlier nodes also receive diffs from the node after
        # them (via bottom_diff_h) and error carried along the constant
        # error carousel (via bottom_diff_s).
        for t in reversed(range(seq_len)):
            node = self.lstm_node_list[t]
            total_loss += loss_layer.loss(node.state.h, y_list[t])
            diff_h = loss_layer.bottom_diff(node.state.h, y_list[t])
            if t == seq_len - 1:
                # h(t+1) does not exist, so s contributes nothing extra here
                diff_s = np.zeros(self.lstm_param.mem_cell_ct)
            else:
                nxt = self.lstm_node_list[t + 1].state
                diff_h += nxt.bottom_diff_h
                diff_s = nxt.bottom_diff_s
            node.top_diff_is(diff_h, diff_s)

        return total_loss

    def x_list_clear(self):
        """Forget the current input sequence (node objects are reused)."""
        self.x_list = []

    def x_list_add(self, x):
        """Append input x as the next timestep and run its forward pass."""
        self.x_list.append(x)
        if len(self.x_list) > len(self.lstm_node_list):
            # sequence grew past what we have unrolled: add a node + state
            fresh_state = LstmState(self.lstm_param.mem_cell_ct, self.lstm_param.x_dim)
            self.lstm_node_list.append(LstmNode(self.lstm_param, fresh_state))

        # index of the timestep we just added
        idx = len(self.x_list) - 1
        if idx == 0:
            # no recurrent inputs yet
            self.lstm_node_list[idx].bottom_data_is(x)
        else:
            prev_state = self.lstm_node_list[idx - 1].state
            self.lstm_node_list[idx].bottom_data_is(x, prev_state.s, prev_state.h)
| 190 | + |
0 commit comments