Skip to content

Commit edbf3f9

Browse files
author
Yibing Liu
authored
Merge pull request PaddlePaddle#12 from tianxin1860/develop
code format
2 parents f744d0e + 78489c4 commit edbf3f9

9 files changed

Lines changed: 199 additions & 129 deletions

File tree

ERNIE/batching.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919

2020
import numpy as np
2121

22-
def mask(batch_tokens, seg_labels, mask_word_tags, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
22+
23+
def mask(batch_tokens,
24+
seg_labels,
25+
mask_word_tags,
26+
total_token_num,
27+
vocab_size,
28+
CLS=1,
29+
SEP=2,
30+
MASK=3):
2331
"""
2432
Add mask for batch_tokens, return out, mask_label, mask_pos;
2533
Note: mask_pos responding the batch_tokens after padded;
@@ -90,7 +98,8 @@ def mask(batch_tokens, seg_labels, mask_word_tags, total_token_num, vocab_size,
9098
# random replace
9199
if token != SEP and token != CLS:
92100
mask_label.append(sent[token_index])
93-
sent[token_index] = replace_ids[prob_index + token_index]
101+
sent[token_index] = replace_ids[prob_index +
102+
token_index]
94103
mask_flag = True
95104
mask_pos.append(sent_index * max_len + token_index)
96105
else:
@@ -143,7 +152,10 @@ def prepare_batch_data(insts,
143152
pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
144153
sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)
145154

146-
return_list = [src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos, labels, next_sent_index]
155+
return_list = [
156+
src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos, labels,
157+
next_sent_index
158+
]
147159

148160
return return_list
149161

@@ -207,4 +219,5 @@ def pad_batch_data(insts,
207219

208220

209221
if __name__ == "__main__":
222+
210223
pass

ERNIE/finetune/classifier.py

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,20 @@
2525
from model.ernie import ErnieModel
2626

2727

28-
def create_model(args,
29-
pyreader_name,
30-
ernie_config,
31-
is_prediction=False):
28+
def create_model(args, pyreader_name, ernie_config, is_prediction=False):
3229
pyreader = fluid.layers.py_reader(
3330
capacity=50,
3431
shapes=[[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
3532
[-1, args.max_seq_len, 1],
36-
[-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1], [-1, 1]],
33+
[-1, args.max_seq_len, args.max_seq_len], [-1, 1], [-1, 1],
34+
[-1, 1]],
3735
dtypes=['int64', 'int64', 'int64', 'float', 'int64', 'int64', 'int64'],
3836
lod_levels=[0, 0, 0, 0, 0, 0, 0],
3937
name=pyreader_name,
4038
use_double_buffer=True)
4139

42-
(src_ids, sent_ids, pos_ids, self_attn_mask, labels,
43-
next_sent_index, qids) = fluid.layers.read_file(pyreader)
40+
(src_ids, sent_ids, pos_ids, self_attn_mask, labels, next_sent_index,
41+
qids) = fluid.layers.read_file(pyreader)
4442

4543
ernie = ErnieModel(
4644
src_ids=src_ids,
@@ -57,7 +55,7 @@ def create_model(args,
5755
dropout_implementation="upscale_in_train")
5856
logits = fluid.layers.fc(
5957
input=cls_feats,
60-
size=ernie_config["num_labels"],
58+
size=args.num_labels,
6159
param_attr=fluid.ParamAttr(
6260
name="cls_out_w",
6361
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
@@ -82,18 +80,21 @@ def create_model(args,
8280
num_seqs = fluid.layers.create_tensor(dtype='int64')
8381
accuracy = fluid.layers.accuracy(input=probs, label=labels, total=num_seqs)
8482

85-
graph_vars = {"loss": loss,
86-
"probs": probs,
87-
"accuracy": accuracy,
88-
"labels": labels,
89-
"num_seqs": num_seqs,
90-
"qids": qids}
83+
graph_vars = {
84+
"loss": loss,
85+
"probs": probs,
86+
"accuracy": accuracy,
87+
"labels": labels,
88+
"num_seqs": num_seqs,
89+
"qids": qids
90+
}
9191

9292
for k, v in graph_vars.items():
93-
v.persistable=True
93+
v.persistable = True
9494

9595
return pyreader, graph_vars
9696

97+
9798
def evaluate_mrr(preds):
9899
last_qid = None
99100
total_mrr = 0.0
@@ -114,6 +115,7 @@ def evaluate_mrr(preds):
114115

115116
return total_mrr / qnum
116117

118+
117119
def evaluate_map(preds):
118120
def singe_map(st, en):
119121
total_p = 0.0
@@ -142,17 +144,18 @@ def singe_map(st, en):
142144
total_map += singe_map(st, len(preds))
143145
return total_map / qnum
144146

147+
145148
def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
146-
train_fetch_list = [graph_vars["loss"].name,
147-
graph_vars["accuracy"].name,
148-
graph_vars["num_seqs"].name
149-
]
149+
train_fetch_list = [
150+
graph_vars["loss"].name, graph_vars["accuracy"].name,
151+
graph_vars["num_seqs"].name
152+
]
150153

151154
if eval_phase == "train":
152155
if "learning_rate" in graph_vars:
153156
train_fetch_list.append(graph_vars["learning_rate"].name)
154157
outputs = exe.run(fetch_list=train_fetch_list)
155-
ret = {"loss":np.mean(outputs[0]), "accuracy":np.mean(outputs[1])}
158+
ret = {"loss": np.mean(outputs[0]), "accuracy": np.mean(outputs[1])}
156159
if "learning_rate" in graph_vars:
157160
ret["learning_rate"] = float(outputs[4][0])
158161
return ret
@@ -162,22 +165,21 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
162165
qids, labels, scores = [], [], []
163166
time_begin = time.time()
164167

165-
fetch_list = [graph_vars["loss"].name,
166-
graph_vars["accuracy"].name,
167-
graph_vars["probs"].name,
168-
graph_vars["labels"].name,
169-
graph_vars["num_seqs"].name,
170-
graph_vars["qids"].name]
168+
fetch_list = [
169+
graph_vars["loss"].name, graph_vars["accuracy"].name,
170+
graph_vars["probs"].name, graph_vars["labels"].name,
171+
graph_vars["num_seqs"].name, graph_vars["qids"].name
172+
]
171173
while True:
172174
try:
173-
np_loss, np_acc, np_probs, np_labels, np_num_seqs, np_qids = exe.run(program=test_program,
174-
fetch_list=fetch_list)
175+
np_loss, np_acc, np_probs, np_labels, np_num_seqs, np_qids = exe.run(
176+
program=test_program, fetch_list=fetch_list)
175177
total_cost += np.sum(np_loss * np_num_seqs)
176178
total_acc += np.sum(np_acc * np_num_seqs)
177179
total_num_seqs += np.sum(np_num_seqs)
178180
labels.extend(np_labels.reshape((-1)).tolist())
179181
qids.extend(np_qids.reshape(-1).tolist())
180-
scores.extend(np_probs[:,1].reshape(-1).tolist())
182+
scores.extend(np_probs[:, 1].reshape(-1).tolist())
181183
np_preds = np.argmax(np_probs, axis=1).astype(np.float32)
182184
total_label_pos_num += np.sum(np_labels)
183185
total_pred_pos_num += np.sum(np_preds)
@@ -188,20 +190,23 @@ def evaluate(exe, test_program, test_pyreader, graph_vars, eval_phase):
188190
time_end = time.time()
189191

190192
if len(qids) == 0:
191-
print("[%s evaluation] ave loss: %f, ave acc: %f, data_num: %d, elapsed time: %f s" %
192-
(eval_phase, total_cost / total_num_seqs,
193-
total_acc / total_num_seqs, total_num_seqs, time_end - time_begin))
193+
print(
194+
"[%s evaluation] ave loss: %f, ave acc: %f, data_num: %d, elapsed time: %f s"
195+
% (eval_phase, total_cost / total_num_seqs, total_acc /
196+
total_num_seqs, total_num_seqs, time_end - time_begin))
194197
else:
195198
r = total_correct_num / total_label_pos_num
196199
p = total_correct_num / total_pred_pos_num
197200
f = 2 * p * r / (p + r)
198201

199202
assert len(qids) == len(labels) == len(scores)
200-
preds = sorted(zip(qids, scores, labels), key=lambda elem:(elem[0], -elem[1]))
203+
preds = sorted(
204+
zip(qids, scores, labels), key=lambda elem: (elem[0], -elem[1]))
201205
mrr = evaluate_mrr(preds)
202206
map = evaluate_map(preds)
203207

204-
print("[%s evaluation] ave loss: %f, ave_acc: %f, mrr: %f, map: %f, p: %f, r: %f, f1: %f, data_num: %d, elapsed time: %f s" %
205-
(eval_phase, total_cost / total_num_seqs,
206-
total_acc / total_num_seqs,
207-
mrr, map, p, r, f, total_num_seqs, time_end - time_begin))
208+
print(
209+
"[%s evaluation] ave loss: %f, ave_acc: %f, mrr: %f, map: %f, p: %f, r: %f, f1: %f, data_num: %d, elapsed time: %f s"
210+
% (eval_phase, total_cost / total_num_seqs,
211+
total_acc / total_num_seqs, mrr, map, p, r, f, total_num_seqs,
212+
time_end - time_begin))

ERNIE/finetune_args.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"Whether to lower case the input text. Should be True for uncased models and False for cased models.")
6565
data_g.add_arg("random_seed", int, 0, "Random seed.")
6666
data_g.add_arg("label_map_config", str, None, "label_map_path.")
67-
data_g.add_arg("num_labels", int, 2, "label number")
67+
data_g.add_arg("num_labels", int, 2, "label number")
6868

6969
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
7070
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
@@ -74,3 +74,4 @@
7474
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
7575
run_type_g.add_arg("do_test", bool, True, "Whether to perform evaluation on test data set.")
7676
run_type_g.add_arg("metrics", bool, True, "Whether to perform evaluation on test data set.")
77+
# yapf: enable

ERNIE/pretrain_args.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
# yapf: disable
2626
parser = argparse.ArgumentParser(__doc__)
27-
parser = argparse.ArgumentParser(__doc__)
2827
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
2928
model_g.add_arg("ernie_config_path", str, "./config/ernie_config.json", "Path to the json file for ernie model config.")
3029
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")

ERNIE/reader/pretraining.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
from batching import prepare_batch_data
3232

33+
3334
class ErnieDataReader(object):
3435
def __init__(self,
3536
filelist,
@@ -81,8 +82,8 @@ def parse_line(self, line, max_seq_len=512):
8182
sent_ids = [int(token) for token in sent_ids.split(" ")]
8283
pos_ids = [int(token) for token in pos_ids.split(" ")]
8384
seg_labels = [int(seg_label) for seg_label in seg_labels.split(" ")]
84-
assert len(token_ids) == len(sent_ids) == len(
85-
pos_ids) == len(seg_labels
85+
assert len(token_ids) == len(sent_ids) == len(pos_ids) == len(
86+
seg_labels
8687
), "[Must be true]len(token_ids) == len(sent_ids) == len(pos_ids) == len(seg_labels)"
8788
label = int(label)
8889
if len(token_ids) > max_seq_len:
@@ -153,14 +154,17 @@ def split_sent(sample, max_len, sep_id):
153154
if left_len <= max_len:
154155
return (token_seq[1:sep_index], seg_labels[1:sep_index])
155156
else:
156-
return [token_seq[sep_index + 1: -1], seg_labels[sep_index + 1 : -1]]
157+
return [
158+
token_seq[sep_index + 1:-1], seg_labels[sep_index + 1:-1]
159+
]
157160

158161
for i in range(num_sample):
159162
pair_index = (i + 1) % num_sample
160-
left_tokens, left_seg_labels = split_sent(pos_samples[i],
161-
(self.max_seq_len - 3) // 2, self.sep_id)
162-
right_tokens, right_seg_labels = split_sent(pos_samples[pair_index],
163-
self.max_seq_len - 3 - len(left_tokens), self.sep_id)
163+
left_tokens, left_seg_labels = split_sent(
164+
pos_samples[i], (self.max_seq_len - 3) // 2, self.sep_id)
165+
right_tokens, right_seg_labels = split_sent(
166+
pos_samples[pair_index],
167+
self.max_seq_len - 3 - len(left_tokens), self.sep_id)
164168

165169
token_seq = [self.cls_id] + left_tokens + [self.sep_id] + \
166170
right_tokens + [self.sep_id]

0 commit comments

Comments
 (0)