Skip to content

Commit 51002cb

Browse files
author
李闯
committed
chatbotv5解决内存问题
1 parent 19fbb7f commit 51002cb

2 files changed

Lines changed: 26 additions & 21 deletions

File tree

chatbotv5/demo.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq
88
import word_token
99
import jieba
10+
import random
1011

1112
# 输入序列长度
1213
input_seq_len = 5
@@ -20,17 +21,15 @@
2021
EOS_ID = 2
2122
# LSTM神经元size
2223
size = 8
23-
# 最大输入符号数
24-
num_encoder_symbols = 32
25-
# 最大输出符号数
26-
num_decoder_symbols = 32
2724
# 初始学习率
2825
init_learning_rate = 1
26+
# 在样本中出现频率超过这个值才会进入词表
27+
min_freq = 10
2928

3029
wordToken = word_token.WordToken()
3130

3231
# 放在全局的位置,为了动态算出num_encoder_symbols和num_decoder_symbols
33-
max_token_id = wordToken.load_file_list(['./samples/question', './samples/answer'])
32+
max_token_id = wordToken.load_file_list(['./samples/question', './samples/answer'], min_freq)
3433
num_encoder_symbols = max_token_id + 5
3534
num_decoder_symbols = max_token_id + 5
3635

@@ -59,14 +58,15 @@ def get_train_set():
5958

6059
question_id_list = get_id_list_from(question)
6160
answer_id_list = get_id_list_from(answer)
62-
answer_id_list.append(EOS_ID)
63-
train_set.append([question_id_list, answer_id_list])
61+
if len(question_id_list) > 0 and len(answer_id_list) > 0:
62+
answer_id_list.append(EOS_ID)
63+
train_set.append([question_id_list, answer_id_list])
6464
else:
6565
break
6666
return train_set
6767

6868

69-
def get_samples(train_set):
69+
def get_samples(train_set, batch_num):
7070
"""构造样本数据
7171
7272
:return:
@@ -78,7 +78,12 @@ def get_samples(train_set):
7878
# train_set = [[[5, 7, 9], [11, 13, 15, EOS_ID]], [[7, 9, 11], [13, 15, 17, EOS_ID]], [[15, 17, 19], [21, 23, 25, EOS_ID]]]
7979
raw_encoder_input = []
8080
raw_decoder_input = []
81-
for sample in train_set:
81+
if batch_num >= len(train_set):
82+
batch_train_set = train_set
83+
else:
84+
random_start = random.randint(0, len(train_set)-batch_num)
85+
batch_train_set = train_set[random_start:random_start+batch_num]
86+
for sample in batch_train_set:
8287
raw_encoder_input.append([PAD_ID] * (input_seq_len - len(sample[0])) + sample[0])
8388
raw_decoder_input.append([GO_ID] + sample[1] + [PAD_ID] * (output_seq_len - len(sample[1]) - 1))
8489

@@ -163,23 +168,22 @@ def train():
163168
train_set = get_train_set()
164169
with tf.Session() as sess:
165170

166-
sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set)
167171
encoder_inputs, decoder_inputs, target_weights, outputs, loss, update, saver, learning_rate_decay_op, learning_rate = get_model()
168172

169-
input_feed = {}
170-
for l in xrange(input_seq_len):
171-
input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
172-
for l in xrange(output_seq_len):
173-
input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
174-
input_feed[target_weights[l].name] = sample_target_weights[l]
175-
input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32)
176-
177173
# 全部变量初始化
178174
sess.run(tf.global_variables_initializer())
179175

180176
# 训练很多次迭代,每隔10次打印一次loss,可以看情况直接ctrl+c停止
181177
previous_losses = []
182-
for step in xrange(20700):
178+
for step in xrange(20000):
179+
sample_encoder_inputs, sample_decoder_inputs, sample_target_weights = get_samples(train_set, 1000)
180+
input_feed = {}
181+
for l in xrange(input_seq_len):
182+
input_feed[encoder_inputs[l].name] = sample_encoder_inputs[l]
183+
for l in xrange(output_seq_len):
184+
input_feed[decoder_inputs[l].name] = sample_decoder_inputs[l]
185+
input_feed[target_weights[l].name] = sample_target_weights[l]
186+
input_feed[decoder_inputs[output_seq_len].name] = np.zeros([len(sample_decoder_inputs[0])], dtype=np.int32)
183187
[loss_ret, _] = sess.run([loss, update], input_feed)
184188
if step % 10 == 0:
185189
print 'step=', step, 'loss=', loss_ret, 'learning_rate=', learning_rate.eval()

chatbotv5/word_token.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self):
1111
self.id2word_dict = {}
1212

1313

14-
def load_file_list(self, file_list):
14+
def load_file_list(self, file_list, min_freq):
1515
"""
1616
加载样本文件列表,全部切词后统计词频,按词频由高到低排序后顺次编号
1717
并存到self.word2id_dict和self.id2word_dict中
@@ -32,6 +32,8 @@ def load_file_list(self, file_list):
3232
sorted_list.sort(reverse=True)
3333
for index, item in enumerate(sorted_list):
3434
word = item[1]
35+
if item[0] < min_freq:
36+
break
3537
self.word2id_dict[word] = self.START_ID + index
3638
self.id2word_dict[self.START_ID + index] = word
3739
return index
@@ -45,7 +47,6 @@ def word2id(self, word):
4547
else:
4648
return None
4749

48-
4950
def id2word(self, id):
5051
id = int(id)
5152
if id in self.id2word_dict:

0 commit comments

Comments (0)