@@ -121,13 +121,17 @@ def generate_trainig_data(self):
121121 init_seq ()
122122 xy_data = []
123123 y_data = []
124- for i in range (10 , 30 ,10 ):
124+ for i in range (30 , 40 ,10 ):
125125 # 问句、答句都是16字,所以取32个
126126 start = i * self .max_seq_len * 2
127127 middle = i * self .max_seq_len * 2 + self .max_seq_len
128128 end = (i + 1 )* self .max_seq_len * 2
129129 sequence_xy = seq [start :end ]
130130 sequence_y = seq [middle :end ]
131+ print "right answer"
132+ for w in sequence_y :
133+ (match_word , max_cos ) = vector2word (w )
134+ print match_word
131135 sequence_y = [np .ones (self .word_vec_dim )] + sequence_y
132136 xy_data .append (sequence_xy )
133137 y_data .append (sequence_y )
@@ -179,7 +183,7 @@ def model(self, feed_previous=False):
179183 def train (self ):
180184 trainXY , trainY = self .generate_trainig_data ()
181185 model = self .model (feed_previous = False )
182- model .fit (trainXY , trainY , n_epoch = 100 , snapshot_epoch = False )
186+ model .fit (trainXY , trainY , n_epoch = 1000 , snapshot_epoch = False )
183187 model .save ('./model/model' )
184188 return model
185189
@@ -189,6 +193,16 @@ def load(self):
189193 return model
190194
191195if __name__ == '__main__' :
196+ phrase = sys .argv [1 ]
192197 my_seq2seq = MySeq2Seq (word_vec_dim = word_vec_dim , max_seq_len = max_seq_len )
193- my_seq2seq .train ()
194- #model = my_seq2seq.load()
198+ if phrase == 'train' :
199+ my_seq2seq .train ()
200+ else :
201+ model = my_seq2seq .load ()
202+ trainXY , trainY = my_seq2seq .generate_trainig_data ()
203+ predict = model .predict (trainXY )
204+ for sample in predict :
205+ print "predict answer"
206+ for w in sample [1 :]:
207+ (match_word , max_cos ) = vector2word (w )
208+ print match_word , max_cos
0 commit comments