77from tensorflow .contrib .legacy_seq2seq .python .ops import seq2seq
88import word_token
99import jieba
10+ import random
1011
1112# Input sequence length
1213input_seq_len = 5
2021EOS_ID = 2
2122# LSTM neuron size
2223size = 8
23- # Maximum number of input symbols
24- num_encoder_symbols = 32
25- # Maximum number of output symbols
26- num_decoder_symbols = 32
2724# Initial learning rate
2825init_learning_rate = 1
26+ # A word enters the vocabulary only if its frequency in the samples exceeds this value
27+ min_freq = 10
2928
3029wordToken = word_token .WordToken ()
3130
3231# Kept at module level so that num_encoder_symbols and num_decoder_symbols can be computed dynamically
33- max_token_id = wordToken .load_file_list (['./samples/question' , './samples/answer' ])
32+ max_token_id = wordToken .load_file_list (['./samples/question' , './samples/answer' ], min_freq )
3433num_encoder_symbols = max_token_id + 5  # NOTE(review): the +5 presumably leaves headroom for special ids such as PAD/GO/EOS - TODO confirm
3534num_decoder_symbols = max_token_id + 5
3635
@@ -59,14 +58,15 @@ def get_train_set():
5958
6059 question_id_list = get_id_list_from (question )
6160 answer_id_list = get_id_list_from (answer )
62- answer_id_list .append (EOS_ID )
63- train_set .append ([question_id_list , answer_id_list ])
61+ if len (question_id_list ) > 0 and len (answer_id_list ) > 0 :
62+ answer_id_list .append (EOS_ID )
63+ train_set .append ([question_id_list , answer_id_list ])
6464 else :
6565 break
6666 return train_set
6767
6868
69- def get_samples (train_set ):
69+ def get_samples (train_set , batch_num ):
7070 """构造样本数据
7171
7272 :return:
@@ -78,7 +78,12 @@ def get_samples(train_set):
7878 # train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]], [[15, 17, 19], [21, 23, 25, EOS_ID]]]
7979 raw_encoder_input = []
8080 raw_decoder_input = []
81- for sample in train_set :
81+ if batch_num >= len (train_set ):
82+ batch_train_set = train_set
83+ else :
84+ random_start = random .randint (0 , len (train_set )- batch_num )
85+ batch_train_set = train_set [random_start :random_start + batch_num ]
86+ for sample in batch_train_set :
8287 raw_encoder_input .append ([PAD_ID ] * (input_seq_len - len (sample [0 ])) + sample [0 ])
8388 raw_decoder_input .append ([GO_ID ] + sample [1 ] + [PAD_ID ] * (output_seq_len - len (sample [1 ]) - 1 ))
8489
@@ -163,23 +168,22 @@ def train():
163168 train_set = get_train_set ()
164169 with tf .Session () as sess :
165170
166- sample_encoder_inputs , sample_decoder_inputs , sample_target_weights = get_samples (train_set )
167171 encoder_inputs , decoder_inputs , target_weights , outputs , loss , update , saver , learning_rate_decay_op , learning_rate = get_model ()
168172
169- input_feed = {}
170- for l in xrange (input_seq_len ):
171- input_feed [encoder_inputs [l ].name ] = sample_encoder_inputs [l ]
172- for l in xrange (output_seq_len ):
173- input_feed [decoder_inputs [l ].name ] = sample_decoder_inputs [l ]
174- input_feed [target_weights [l ].name ] = sample_target_weights [l ]
175- input_feed [decoder_inputs [output_seq_len ].name ] = np .zeros ([len (sample_decoder_inputs [0 ])], dtype = np .int32 )
176-
177173 # 全部变量初始化
178174 sess .run (tf .global_variables_initializer ())
179175
180176 # 训练很多次迭代,每隔10次打印一次loss,可以看情况直接ctrl+c停止
181177 previous_losses = []
182- for step in xrange (20700 ):
178+ for step in xrange (20000 ):
179+ sample_encoder_inputs , sample_decoder_inputs , sample_target_weights = get_samples (train_set , 1000 )
180+ input_feed = {}
181+ for l in xrange (input_seq_len ):
182+ input_feed [encoder_inputs [l ].name ] = sample_encoder_inputs [l ]
183+ for l in xrange (output_seq_len ):
184+ input_feed [decoder_inputs [l ].name ] = sample_decoder_inputs [l ]
185+ input_feed [target_weights [l ].name ] = sample_target_weights [l ]
186+ input_feed [decoder_inputs [output_seq_len ].name ] = np .zeros ([len (sample_decoder_inputs [0 ])], dtype = np .int32 )
183187 [loss_ret , _ ] = sess .run ([loss , update ], input_feed )
184188 if step % 10 == 0 :
185189 print 'step=' , step , 'loss=' , loss_ret , 'learning_rate=' , learning_rate .eval ()
0 commit comments