88import tflearn
99
1010max_seq_len = 8
11- learning_rate = 0.001
11+ learning_rate = 0.01
12+ id_word_dict = {}
1213
1314# 得到了单词转id的词典是word_id_dict, 最大单词id是max_word_id
1415def init_word_id_dict ():
@@ -43,9 +44,13 @@ def init_word_id_dict():
4344
4445 uuid = 1
4546
46- max_word_id = 1500
47+ max_word_id = 2000
4748 for (word , freq ) in vocab_dict :
4849 word_id_dict [word ] = uuid
50+ id_word_dict [uuid ] = word
51+ #if freq > 20:
52+ # print word, uuid, freq
53+ print word , uuid , freq
4954 uuid = uuid + 1
5055 if uuid > max_word_id :
5156 break
@@ -63,7 +68,7 @@ def accuracy(y_pred, y_true, x_in):
6368 pred_idx = tf .to_int32 (tf .argmax (y_pred , 2 ))
6469 return tf .reduce_mean (tf .cast (tf .equal (pred_idx , y_true ), tf .float32 ), name = 'acc' )
6570
66- def create_model (max_word_id ):
71+ def create_model (max_word_id , is_test = False ):
6772 GO_VALUE = max_word_id + 1
6873 network = tflearn .input_data (shape = [None , max_seq_len + max_seq_len ], dtype = tf .int32 , name = "XY" )
6974 encoder_inputs = tf .slice (network , [0 , 0 ], [- 1 , max_seq_len ], name = "enc_in" )
@@ -75,7 +80,7 @@ def create_model(max_word_id):
7580 num_encoder_symbols = max_word_id + 1 # 从0起始
7681 num_decoder_symbols = max_word_id + 2 # 包括GO
7782
78- cell = rnn_cell .BasicLSTMCell (max_seq_len + max_seq_len , state_is_tuple = True )
83+ cell = rnn_cell .BasicLSTMCell (16 * max_seq_len , state_is_tuple = True )
7984
8085 model_outputs , states = seq2seq .embedding_rnn_seq2seq (
8186 encoder_inputs ,
@@ -84,7 +89,7 @@ def create_model(max_word_id):
8489 num_encoder_symbols = num_encoder_symbols ,
8590 num_decoder_symbols = num_decoder_symbols ,
8691 embedding_size = max_word_id ,
87- feed_previous = False )
92+ feed_previous = is_test )
8893
8994 network = tf .pack (model_outputs , axis = 1 )
9095
@@ -107,12 +112,22 @@ def create_model(max_word_id):
107112 print "create DNN model finish"
108113 return model
109114
def print_sentence(id_list, msg, word_dict=None):
    """Print and return `msg` followed by the words for the non-zero ids in `id_list`.

    Zero ids are padding (the question/answer arrays are np.zeros-filled)
    and are skipped. `word_dict` defaults to the module-level id_word_dict
    built by init_word_id_dict(); ids missing from the mapping — e.g. the
    decoder's GO symbol, max_word_id + 1, which argmax over the
    num_decoder_symbols outputs can produce — render as '' instead of
    raising KeyError as the original did.
    """
    if word_dict is None:
        word_dict = id_word_dict
    sentence = msg
    for item in id_list:
        if item != 0:
            # .get() tolerates out-of-vocabulary ids such as the GO symbol.
            sentence = sentence + word_dict.get(item, "")
    # Parenthesized single-argument print is valid in both Python 2 and 3.
    print(sentence)
    return sentence
110121
111122if __name__ == '__main__' :
123+ if len (sys .argv ) > 1 and sys .argv [1 ] == 'test' :
124+ is_test = True
125+ else :
126+ is_test = False
112127 (word_id_dict , max_word_id ) = init_word_id_dict ()
113128 print "max_word_id =" , max_word_id
114129
115- model = create_model (max_word_id )
130+ model = create_model (max_word_id , is_test )
116131
117132 threshold = max_seq_len
118133 file_object = open ("chat_dev.data" , "r" )
@@ -138,14 +153,20 @@ def create_model(max_word_id):
138153 # 保证连续的话才参与训练
139154 if last_line_no != 0 and last_line_no == cur_line_no - 1 :
140155 question_id_list = []
156+ question = ""
157+ answer = ""
141158 question_array = np .zeros (max_seq_len + max_seq_len )
142159 answer_array = np .zeros (max_seq_len )
143160 idx = 0
161+ question_has_word = False
162+ answer_has_word = False
144163 for word in last_words :
145164 if len (word )> 0 and word_id_dict .has_key (word ):
146165 word_id = word_id_dict [word ]
147166 question_id_list .append (word_id )
167+ question = question + word
148168 question_array [idx ] = word_id
169+ question_has_word = True
149170 idx = idx + 1
150171 for i in range (max_seq_len - len (question_id_list )):
151172 question_id_list .append (0 )
@@ -157,20 +178,21 @@ def create_model(max_word_id):
157178 if len (word )> 0 and word_id_dict .has_key (word ):
158179 word_id = word_id_dict [word ]
159180 answer_id_list .append (word_id )
181+ answer = answer + word
160182 question_array [max_seq_len + idx ] = word_id
161183 answer_array [idx ] = word_id
184+ answer_has_word = True
162185 idx = idx + 1
163186 for i in range (2 * max_seq_len - len (question_id_list )):
164187 answer_id_list .append (0 )
165188 question_id_list .extend (answer_id_list )
166189
167- XY .append (question_array )
168- Y .append (answer_array )
169- sample_count = sample_count + 1
170-
171- #if sample_count > 0:
172- # break
173-
190+ if question_has_word and answer_has_word :
191+ #print "question =", question
192+ #print "answer =", answer
193+ XY .append (question_array )
194+ Y .append (answer_array )
195+ sample_count = sample_count + 1
174196
175197 last_words = words
176198 last_line = line
@@ -180,29 +202,39 @@ def create_model(max_word_id):
180202 break
181203 file_object .close ()
182204
183- model .fit (
184- XY ,
185- Y ,
186- n_epoch = 100 ,
187- validation_set = 0.01 ,
188- batch_size = 1 ,
189- shuffle = True ,
190- show_metric = True ,
191- snapshot_step = 5000 ,
192- snapshot_epoch = False ,
193- run_id = "my_lstm_test" )
205+ if not is_test :
206+ model .fit (
207+ XY ,
208+ Y ,
209+ n_epoch = 3000 ,
210+ validation_set = 0.01 ,
211+ batch_size = 64 ,
212+ shuffle = True ,
213+ show_metric = True ,
214+ snapshot_step = 5000 ,
215+ snapshot_epoch = False ,
216+ run_id = "my_lstm_test" )
194217
195- model .save ("./weights" )
196- #model.load("./weights")
218+ model .save ("./weights" )
219+ else :
220+ model .load ("./weights" )
197221
198222
199223 # predict
200- TEST_XY = [XY [0 ]]
201- res = model .predict (TEST_XY )
202- res = np .array (res )
203- num_decoder_symbols = max_word_id + 2
204- y = res .reshape (max_seq_len , num_decoder_symbols )
205- prediction = np .argmax (y , axis = 1 )
206- print TEST_XY
207- print "desire =" , Y [0 ]
208- print "prediction =" , prediction
224+ for i in range (100 ):
225+ TEST_XY = [XY [i ]]
226+ TEST_XY [0 ][max_seq_len :2 * max_seq_len ] = 0
227+ #TEST_XY[0][0:2*max_seq_len] = 0
228+ #TEST_XY[0][0] = 5
229+ #TEST_XY[0][1] = 4
230+ #TEST_XY[0][2] = 109
231+
232+ res = model .predict (TEST_XY )
233+ res = np .array (res )
234+ num_decoder_symbols = max_word_id + 2
235+ y = res .reshape (max_seq_len , num_decoder_symbols )
236+ prediction = np .argmax (y , axis = 1 )
237+ if 0 != np .sum (prediction ):
238+ print_sentence (TEST_XY [0 ], "input " )
239+ print_sentence (Y [i ], "desire " )
240+ print_sentence (prediction , "prediction " )
0 commit comments