2525from model .ernie import ErnieModel
2626
2727
28- def create_model (args ,
29- pyreader_name ,
30- ernie_config ,
31- is_prediction = False ):
28+ def create_model (args , pyreader_name , ernie_config , is_prediction = False ):
3229 pyreader = fluid .layers .py_reader (
3330 capacity = 50 ,
3431 shapes = [[- 1 , args .max_seq_len , 1 ], [- 1 , args .max_seq_len , 1 ],
3532 [- 1 , args .max_seq_len , 1 ],
36- [- 1 , args .max_seq_len , args .max_seq_len ], [- 1 , 1 ], [- 1 , 1 ], [- 1 , 1 ]],
33+ [- 1 , args .max_seq_len , args .max_seq_len ], [- 1 , 1 ], [- 1 , 1 ],
34+ [- 1 , 1 ]],
3735 dtypes = ['int64' , 'int64' , 'int64' , 'float' , 'int64' , 'int64' , 'int64' ],
3836 lod_levels = [0 , 0 , 0 , 0 , 0 , 0 , 0 ],
3937 name = pyreader_name ,
4038 use_double_buffer = True )
4139
42- (src_ids , sent_ids , pos_ids , self_attn_mask , labels ,
43- next_sent_index , qids ) = fluid .layers .read_file (pyreader )
40+ (src_ids , sent_ids , pos_ids , self_attn_mask , labels , next_sent_index ,
41+ qids ) = fluid .layers .read_file (pyreader )
4442
4543 ernie = ErnieModel (
4644 src_ids = src_ids ,
@@ -57,7 +55,7 @@ def create_model(args,
5755 dropout_implementation = "upscale_in_train" )
5856 logits = fluid .layers .fc (
5957 input = cls_feats ,
60- size = ernie_config [ " num_labels" ] ,
58+ size = args . num_labels ,
6159 param_attr = fluid .ParamAttr (
6260 name = "cls_out_w" ,
6361 initializer = fluid .initializer .TruncatedNormal (scale = 0.02 )),
@@ -82,18 +80,21 @@ def create_model(args,
8280 num_seqs = fluid .layers .create_tensor (dtype = 'int64' )
8381 accuracy = fluid .layers .accuracy (input = probs , label = labels , total = num_seqs )
8482
85- graph_vars = {"loss" : loss ,
86- "probs" : probs ,
87- "accuracy" : accuracy ,
88- "labels" : labels ,
89- "num_seqs" : num_seqs ,
90- "qids" : qids }
83+ graph_vars = {
84+ "loss" : loss ,
85+ "probs" : probs ,
86+ "accuracy" : accuracy ,
87+ "labels" : labels ,
88+ "num_seqs" : num_seqs ,
89+ "qids" : qids
90+ }
9191
9292 for k , v in graph_vars .items ():
93- v .persistable = True
93+ v .persistable = True
9494
9595 return pyreader , graph_vars
9696
97+
9798def evaluate_mrr (preds ):
9899 last_qid = None
99100 total_mrr = 0.0
@@ -114,6 +115,7 @@ def evaluate_mrr(preds):
114115
115116 return total_mrr / qnum
116117
118+
117119def evaluate_map (preds ):
118120 def singe_map (st , en ):
119121 total_p = 0.0
@@ -142,17 +144,18 @@ def singe_map(st, en):
142144 total_map += singe_map (st , len (preds ))
143145 return total_map / qnum
144146
147+
145148def evaluate (exe , test_program , test_pyreader , graph_vars , eval_phase ):
146- train_fetch_list = [graph_vars [ "loss" ]. name ,
147- graph_vars ["accuracy" ].name ,
148- graph_vars ["num_seqs" ].name
149- ]
149+ train_fetch_list = [
150+ graph_vars [ "loss" ]. name , graph_vars ["accuracy" ].name ,
151+ graph_vars ["num_seqs" ].name
152+ ]
150153
151154 if eval_phase == "train" :
152155 if "learning_rate" in graph_vars :
153156 train_fetch_list .append (graph_vars ["learning_rate" ].name )
154157 outputs = exe .run (fetch_list = train_fetch_list )
155- ret = {"loss" :np .mean (outputs [0 ]), "accuracy" :np .mean (outputs [1 ])}
158+ ret = {"loss" : np .mean (outputs [0 ]), "accuracy" : np .mean (outputs [1 ])}
156159 if "learning_rate" in graph_vars :
157160 ret ["learning_rate" ] = float (outputs [4 ][0 ])
158161 return ret
@@ -162,22 +165,21 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
162165 qids , labels , scores = [], [], []
163166 time_begin = time .time ()
164167
165- fetch_list = [graph_vars ["loss" ].name ,
166- graph_vars ["accuracy" ].name ,
167- graph_vars ["probs" ].name ,
168- graph_vars ["labels" ].name ,
169- graph_vars ["num_seqs" ].name ,
170- graph_vars ["qids" ].name ]
168+ fetch_list = [
169+ graph_vars ["loss" ].name , graph_vars ["accuracy" ].name ,
170+ graph_vars ["probs" ].name , graph_vars ["labels" ].name ,
171+ graph_vars ["num_seqs" ].name , graph_vars ["qids" ].name
172+ ]
171173 while True :
172174 try :
173- np_loss , np_acc , np_probs , np_labels , np_num_seqs , np_qids = exe .run (program = test_program ,
174- fetch_list = fetch_list )
175+ np_loss , np_acc , np_probs , np_labels , np_num_seqs , np_qids = exe .run (
176+ program = test_program , fetch_list = fetch_list )
175177 total_cost += np .sum (np_loss * np_num_seqs )
176178 total_acc += np .sum (np_acc * np_num_seqs )
177179 total_num_seqs += np .sum (np_num_seqs )
178180 labels .extend (np_labels .reshape ((- 1 )).tolist ())
179181 qids .extend (np_qids .reshape (- 1 ).tolist ())
180- scores .extend (np_probs [:,1 ].reshape (- 1 ).tolist ())
182+ scores .extend (np_probs [:, 1 ].reshape (- 1 ).tolist ())
181183 np_preds = np .argmax (np_probs , axis = 1 ).astype (np .float32 )
182184 total_label_pos_num += np .sum (np_labels )
183185 total_pred_pos_num += np .sum (np_preds )
@@ -188,20 +190,23 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
188190 time_end = time .time ()
189191
190192 if len (qids ) == 0 :
191- print ("[%s evaluation] ave loss: %f, ave acc: %f, data_num: %d, elapsed time: %f s" %
192- (eval_phase , total_cost / total_num_seqs ,
193- total_acc / total_num_seqs , total_num_seqs , time_end - time_begin ))
193+ print (
194+ "[%s evaluation] ave loss: %f, ave acc: %f, data_num: %d, elapsed time: %f s"
195+ % (eval_phase , total_cost / total_num_seqs , total_acc /
196+ total_num_seqs , total_num_seqs , time_end - time_begin ))
194197 else :
195198 r = total_correct_num / total_label_pos_num
196199 p = total_correct_num / total_pred_pos_num
197200 f = 2 * p * r / (p + r )
198201
199202 assert len (qids ) == len (labels ) == len (scores )
200- preds = sorted (zip (qids , scores , labels ), key = lambda elem :(elem [0 ], - elem [1 ]))
203+ preds = sorted (
204+ zip (qids , scores , labels ), key = lambda elem : (elem [0 ], - elem [1 ]))
201205 mrr = evaluate_mrr (preds )
202206 map = evaluate_map (preds )
203207
204- print ("[%s evaluation] ave loss: %f, ave_acc: %f, mrr: %f, map: %f, p: %f, r: %f, f1: %f, data_num: %d, elapsed time: %f s" %
205- (eval_phase , total_cost / total_num_seqs ,
206- total_acc / total_num_seqs ,
207- mrr , map , p , r , f , total_num_seqs , time_end - time_begin ))
208+ print (
209+ "[%s evaluation] ave loss: %f, ave_acc: %f, mrr: %f, map: %f, p: %f, r: %f, f1: %f, data_num: %d, elapsed time: %f s"
210+ % (eval_phase , total_cost / total_num_seqs ,
211+ total_acc / total_num_seqs , mrr , map , p , r , f , total_num_seqs ,
212+ time_end - time_begin ))
0 commit comments