Skip to content

Commit d8afc4d

Browse files
author
tianxin04
committed
release code
1 parent 87f7316 commit d8afc4d

21 files changed

Lines changed: 3611 additions & 0 deletions

ERNIE/batching.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Mask, padding and batching."""
15+
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
import numpy as np
21+
22+
def mask(batch_tokens, seg_labels, mask_word_tags, total_token_num, vocab_size, CLS=1, SEP=2, MASK=3):
    """
    Apply ERNIE-style masking to a batch of token-id lists (in place).

    Two strategies per sentence, selected by mask_word_tags[i]:
      * whole-word masking: contiguous spans delimited by seg_labels
        (seg_label == 1 marks a word-piece continuation, -1 marks a
        boundary token such as [CLS]/[SEP]) are masked as a unit;
      * token-level masking: each non-special token is selected with
        probability 0.15, then 80% -> MASK, 10% -> random id, 10% -> kept.

    Args:
        batch_tokens: list of token-id lists; MUTATED in place.
        seg_labels: per-token segment labels aligned with batch_tokens.
        mask_word_tags: per-sentence bool; True selects whole-word masking.
        total_token_num: total token count of the batch; sizes the
            pre-sampled random arrays (assumed == sum of sentence lengths).
        vocab_size: vocabulary size for random replacement ids.
        CLS, SEP, MASK: special token ids (defaults 1, 2, 3).

    Returns:
        (batch_tokens, mask_label, mask_pos): mask_label/mask_pos are int64
        ndarrays of shape [num_masked, 1]; mask_pos indexes the batch
        flattened as sent_index * max_len + token_index, i.e. positions
        AFTER padding to the batch max length.
    """
    max_len = max([len(sent) for sent in batch_tokens])
    mask_label = []
    mask_pos = []
    prob_mask = np.random.rand(total_token_num)
    # Note: token id 0 is reserved, so replacements are drawn from [1, vocab_size).
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    pre_sent_len = 0
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):
        mask_flag = False
        mask_word = mask_word_tags[sent_index]
        prob_index += pre_sent_len
        if mask_word:
            beg = 0
            for token_index, token in enumerate(sent):
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:
                    # Continuation piece of the current word; span not closed yet.
                    continue
                if beg == 0:
                    if seg_label != -1:
                        beg = token_index
                    continue

                # A word span [beg, token_index) just ended; one draw decides
                # whether the whole word is selected (P = 0.15 on its first piece).
                prob = prob_mask[prob_index + beg]
                if prob > 0.15:
                    pass
                else:
                    # BUGFIX: was `xrange`, which does not exist on Python 3;
                    # the file's __future__ imports indicate py2/py3 support.
                    for index in range(beg, token_index):
                        prob = prob_mask[prob_index + index]
                        base_prob = 1.0
                        if index == beg:
                            base_prob = 0.15
                        if base_prob * 0.2 < prob <= base_prob:
                            # 80%: replace with [MASK].
                            mask_label.append(sent[index])
                            sent[index] = MASK
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            # 10%: replace with a random in-vocab id.
                            mask_label.append(sent[index])
                            sent[index] = replace_ids[prob_index + index]
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        else:
                            # 10%: keep the original token but still predict it.
                            mask_label.append(sent[index])
                            mask_pos.append(sent_index * max_len + index)

                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # 80% of selected tokens: replace with [MASK].
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # 10% of selected tokens: random replace.
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # 10% of selected tokens: keep the original token.
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)

        pre_sent_len = len(sent)

    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos
107+
108+
109+
def prepare_batch_data(insts,
                       total_token_num,
                       voc_size=0,
                       pad_id=None,
                       cls_id=None,
                       sep_id=None,
                       mask_id=None,
                       return_attn_bias=True,
                       return_max_len=True,
                       return_num_token=False):
    """
    Turn a batch of raw instances into model-ready numpy inputs.

    Each instance is a 6-tuple:
        (src_ids, sent_ids, pos_ids, label, seg_labels, mask_word_tag).

    Masking is applied first, on the unpadded token ids, and only then is
    every stream padded to the batch max length.

    NOTE(review): return_attn_bias / return_max_len / return_num_token are
    accepted but never consulted — the attention bias is always built and
    max_len / num_token are never returned; presumably kept for interface
    parity with other data readers. Confirm against callers.

    Returns:
        [src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos,
         labels, next_sent_index]
    """
    src_ids = [inst[0] for inst in insts]
    sent_type_ids = [inst[1] for inst in insts]
    position_ids = [inst[2] for inst in insts]
    label_arr = np.array([inst[3] for inst in insts]).astype("int64").reshape([-1, 1])
    seg_label_list = [inst[4] for inst in insts]
    word_mask_flags = [inst[5] for inst in insts]

    # Step 1: masking before padding, so mask positions are computed
    # against real tokens only.
    assert mask_id >= 0, "[FATAL] mask_id must >= 0"
    masked_src, mask_label, mask_pos = mask(
        src_ids,
        seg_label_list,
        word_mask_flags,
        total_token_num,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id)

    # Step 2: pad every stream to the batch max length.
    src_id, next_sent_index, self_attn_bias = pad_batch_data(
        masked_src, pad_idx=pad_id, return_next_sent_pos=True, return_attn_bias=True)
    pos_id = pad_batch_data(position_ids, pad_idx=pad_id)
    sent_id = pad_batch_data(sent_type_ids, pad_idx=pad_id)

    return [
        src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos,
        label_arr, next_sent_index
    ]
149+
150+
151+
def pad_batch_data(insts,
152+
pad_idx=0,
153+
return_pos=False,
154+
return_next_sent_pos=False,
155+
return_attn_bias=False,
156+
return_max_len=False,
157+
return_num_token=False):
158+
"""
159+
Pad the instances to the max sequence length in batch, and generate the
160+
corresponding position data and attention bias.
161+
"""
162+
return_list = []
163+
max_len = max(len(inst) for inst in insts)
164+
# Any token included in dict can be used to pad, since the paddings' loss
165+
# will be masked out by weights and make no effect on parameter gradients.
166+
167+
inst_data = np.array(
168+
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
169+
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
170+
171+
# next_sent_pos for extract first token embedding of each sentence
172+
if return_next_sent_pos:
173+
batch_size = inst_data.shape[0]
174+
max_seq_len = inst_data.shape[1]
175+
next_sent_index = np.array(
176+
range(0, batch_size * max_seq_len, max_seq_len)).astype(
177+
"int64").reshape(-1, 1)
178+
return_list += [next_sent_index]
179+
180+
# position data
181+
if return_pos:
182+
inst_pos = np.array([
183+
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
184+
for inst in insts
185+
])
186+
187+
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
188+
189+
if return_attn_bias:
190+
# This is used to avoid attention on paddings.
191+
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
192+
(max_len - len(inst)) for inst in insts])
193+
slf_attn_bias_data = np.tile(
194+
slf_attn_bias_data.reshape([-1, 1, max_len]), [1, max_len, 1])
195+
return_list += [slf_attn_bias_data.astype("float32")]
196+
197+
if return_max_len:
198+
return_list += [max_len]
199+
200+
if return_num_token:
201+
num_token = 0
202+
for inst in insts:
203+
num_token += len(inst)
204+
return_list += [num_token]
205+
206+
return return_list if len(return_list) > 1 else return_list[0]
207+
208+
209+
# Module is import-only; no standalone demo or CLI is provided.
if __name__ == "__main__":
    pass

ERNIE/finetune/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)