Commit f744d0e

Author: Yibing Liu
Merge pull request PaddlePaddle#11 from tianxin1860/develop: release ernie code
2 parents (fa75b05 + 354d97a), commit f744d0e

22 files changed: +3741 / -1 lines

ERNIE/README.md

Lines changed: 130 additions & 1 deletion
## Ernie: **E**nhanced **R**epresentation from k**N**owledge **I**nt**E**gration

*Ernie* learns real-world semantic knowledge by modeling the words, entities, and entity relations in massive corpora. Whereas *Bert* learns semantic representations from local co-occurrence in language, *Ernie* models semantic knowledge directly, giving the model a stronger capacity for semantic representation.
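As a toy illustration of this difference (not the repository's implementation; the example sentence, entity spans, and function names are made up here), BERT-style masking hides individual tokens independently, while ERNIE-style knowledge masking hides a whole entity span at once, so the model must recover the entity from world knowledge rather than from the remaining characters of the same word:

```python
import random

# Characters of "Harbin is the capital of Heilongjiang" (assumed segmentation).
tokens = ["哈", "尔", "滨", "是", "黑", "龙", "江", "的", "省", "会"]
entity_spans = [(0, 3), (4, 7)]  # "哈尔滨" and "黑龙江" as entity spans (assumed)

def bert_style_mask(tokens, rate=0.15, seed=0):
    """Mask each token independently with probability `rate`."""
    rng = random.Random(seed)
    return [t if rng.random() > rate else "[MASK]" for t in tokens]

def ernie_style_mask(tokens, spans, rate=0.15, seed=0):
    """Mask each entity span as a unit: the whole span or none of it."""
    rng = random.Random(seed)
    out = list(tokens)
    for beg, end in spans:
        if rng.random() <= rate:
            out[beg:end] = ["[MASK]"] * (end - beg)
    return out
```

With entity-level masking, a model can no longer guess "滨" from the neighboring "哈尔"; it has to infer the whole entity from context.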
In addition, *Ernie* is trained on multi-source Chinese corpora, including encyclopedia articles, news, and forum replies.

We evaluated the model on several public Chinese datasets; compared with *Bert*, *Ernie* achieves better results on all of them.

| Model | XNLI (acc) dev / test | LCQMC (acc) dev / test | MSRA-NER (f1) dev / test | ChnSentiCorp (acc) dev / test | nlpcc-dbqa (mrr) dev / test | nlpcc-dbqa (f1) dev / test |
| --- | --- | --- | --- | --- | --- | --- |
| **Bert** | 78.1 / 77.2 | 88.8 / 87.0 | 94.0 / 92.6 | 94.6 / 94.3 | 94.7 / 94.6 | 80.7 / 80.8 |
| **Ernie** | 79.9 (**+1.8**) / 78.4 (**+1.2**) | 89.7 (**+0.9**) / 87.4 (**+0.4**) | 95.0 (**+1.0**) / 93.8 (**+1.2**) | 95.2 (**+0.6**) / 95.4 (**+1.1**) | 95.0 (**+0.3**) / 95.1 (**+0.5**) | 82.3 (**+1.6**) / 82.7 (**+1.9**) |

#### Datasets

- **Natural language inference** XNLI
  XNLI was built jointly by researchers from Facebook and New York University to evaluate multilingual sentence understanding. The task is to classify the relation between two sentences as contradiction, neutral, or entailment. [Link](https://github.com/facebookresearch/XNLI)

- **Semantic matching** LCQMC
  LCQMC is a question-matching dataset built by Harbin Institute of Technology and published at COLING 2018, a top NLP conference. The task is to judge whether two questions have the same meaning. [Link](http://aclweb.org/anthology/C18-1166)

- **Named entity recognition** MSRA-NER
  MSRA-NER was released by Microsoft Research Asia. The task is named entity recognition: identifying entities with specific meaning in text, mainly person names, place names, and organization names. [Link](http://sighan.cs.uchicago.edu/bakeoff2005/)

- **Sentiment analysis** ChnSentiCorp
  ChnSentiCorp is a Chinese sentiment analysis dataset. The task is to judge the sentiment expressed in a passage.

- **Retrieval question answering** nlpcc-dbqa
  nlpcc-dbqa is an evaluation task organized in 2016 by NLPCC, the International Conference on Natural Language Processing and Chinese Computing. The task is to select the candidate answer that actually answers a given question. [Link](http://tcci.ccf.org.cn/conference/2016/dldoc/evagline2.pdf)
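The five tasks above reduce to a small set of input formats and label spaces. The following summary is illustrative only; the keys and label names are assumptions for exposition, not the repository's actual configuration:

```python
# Hypothetical task summary derived from the dataset descriptions above.
TASKS = {
    "xnli":         {"input": "sentence pair",        "labels": ["contradiction", "neutral", "entailment"]},
    "lcqmc":        {"input": "question pair",        "labels": [0, 1]},  # same meaning or not
    "msra-ner":     {"input": "single sentence",      "labels": "per-token entity tags"},
    "chnsenticorp": {"input": "single passage",       "labels": [0, 1]},  # negative / positive
    "nlpcc-dbqa":   {"input": "question-answer pair", "labels": [0, 1]},  # answers the question or not
}
```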

ERNIE/batching.py

Lines changed: 210 additions & 0 deletions
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mask, padding and batching."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def mask(batch_tokens, seg_labels, mask_word_tags, total_token_num, vocab_size,
         CLS=1, SEP=2, MASK=3):
    """
    Add masks to batch_tokens; return the masked tokens plus mask_label and
    mask_pos. Note: mask_pos refers to positions in batch_tokens after padding.
    """
    max_len = max(len(sent) for sent in batch_tokens)
    mask_label = []
    mask_pos = []
    prob_mask = np.random.rand(total_token_num)
    # Note: the first token is [CLS], so low=1 keeps it out of random replacements.
    replace_ids = np.random.randint(1, high=vocab_size, size=total_token_num)
    pre_sent_len = 0
    prob_index = 0
    for sent_index, sent in enumerate(batch_tokens):
        mask_flag = False
        mask_word = mask_word_tags[sent_index]
        prob_index += pre_sent_len
        if mask_word:
            # Whole-word masking: decide once per segment, then apply the same
            # treatment to every token of that segment.
            beg = 0
            for token_index, token in enumerate(sent):
                seg_label = seg_labels[sent_index][token_index]
                if seg_label == 1:
                    continue
                if beg == 0:
                    if seg_label != -1:
                        beg = token_index
                    continue

                prob = prob_mask[prob_index + beg]
                if prob > 0.15:
                    pass
                else:
                    for index in range(beg, token_index):  # was xrange (Python 2 only)
                        prob = prob_mask[prob_index + index]
                        base_prob = 1.0
                        if index == beg:
                            base_prob = 0.15
                        # Of the selected tokens: 80% -> [MASK], 10% -> random id,
                        # 10% -> keep the original token.
                        if base_prob * 0.2 < prob <= base_prob:
                            mask_label.append(sent[index])
                            sent[index] = MASK
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        elif base_prob * 0.1 < prob <= base_prob * 0.2:
                            mask_label.append(sent[index])
                            sent[index] = replace_ids[prob_index + index]
                            mask_flag = True
                            mask_pos.append(sent_index * max_len + index)
                        else:
                            mask_label.append(sent[index])
                            mask_pos.append(sent_index * max_len + index)

                if seg_label == -1:
                    beg = 0
                else:
                    beg = token_index
        else:
            # Token-level masking: prob <= 0.15 selects 15% of tokens; the 0.03
            # and 0.015 cut points give the same 80/10/10 split within that 15%.
            for token_index, token in enumerate(sent):
                prob = prob_mask[prob_index + token_index]
                if prob > 0.15:
                    continue
                elif 0.03 < prob <= 0.15:
                    # mask
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = MASK
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                elif 0.015 < prob <= 0.03:
                    # random replace
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        sent[token_index] = replace_ids[prob_index + token_index]
                        mask_flag = True
                        mask_pos.append(sent_index * max_len + token_index)
                else:
                    # keep the original token
                    if token != SEP and token != CLS:
                        mask_label.append(sent[token_index])
                        mask_pos.append(sent_index * max_len + token_index)

        pre_sent_len = len(sent)

    mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
    mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
    return batch_tokens, mask_label, mask_pos


def prepare_batch_data(insts,
                       total_token_num,
                       voc_size=0,
                       pad_id=None,
                       cls_id=None,
                       sep_id=None,
                       mask_id=None,
                       return_attn_bias=True,
                       return_max_len=True,
                       return_num_token=False):
    batch_src_ids = [inst[0] for inst in insts]
    batch_sent_ids = [inst[1] for inst in insts]
    batch_pos_ids = [inst[2] for inst in insts]
    labels = [inst[3] for inst in insts]
    labels = np.array(labels).astype("int64").reshape([-1, 1])
    seg_labels = [inst[4] for inst in insts]
    mask_word_tags = [inst[5] for inst in insts]

    # First step: do mask without padding
    assert mask_id >= 0, "[FATAL] mask_id must >= 0"
    out, mask_label, mask_pos = mask(
        batch_src_ids,
        seg_labels,
        mask_word_tags,
        total_token_num,
        vocab_size=voc_size,
        CLS=cls_id,
        SEP=sep_id,
        MASK=mask_id)

    # Second step: padding
    src_id, next_sent_index, self_attn_bias = pad_batch_data(
        out, pad_idx=pad_id, return_next_sent_pos=True, return_attn_bias=True)
    pos_id = pad_batch_data(batch_pos_ids, pad_idx=pad_id)
    sent_id = pad_batch_data(batch_sent_ids, pad_idx=pad_id)

    return_list = [
        src_id, pos_id, sent_id, self_attn_bias, mask_label, mask_pos, labels,
        next_sent_index
    ]

    return return_list


def pad_batch_data(insts,
                   pad_idx=0,
                   return_pos=False,
                   return_next_sent_pos=False,
                   return_attn_bias=False,
                   return_max_len=False,
                   return_num_token=False):
    """
    Pad the instances to the max sequence length in the batch, and generate the
    corresponding position data and attention bias.
    """
    return_list = []
    max_len = max(len(inst) for inst in insts)
    # Any token included in dict can be used to pad, since the paddings' loss
    # will be masked out by weights and make no effect on parameter gradients.

    inst_data = np.array(
        [inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
    return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]

    # next_sent_pos is used to extract the first-token embedding of each sentence
    if return_next_sent_pos:
        batch_size = inst_data.shape[0]
        max_seq_len = inst_data.shape[1]
        next_sent_index = np.array(
            range(0, batch_size * max_seq_len, max_seq_len)).astype(
                "int64").reshape(-1, 1)
        return_list += [next_sent_index]

    # position data
    if return_pos:
        inst_pos = np.array([
            list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
            for inst in insts
        ])

        return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]

    if return_attn_bias:
        # This is used to avoid attention on paddings.
        slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                       (max_len - len(inst)) for inst in insts])
        slf_attn_bias_data = np.tile(
            slf_attn_bias_data.reshape([-1, 1, max_len]), [1, max_len, 1])
        return_list += [slf_attn_bias_data.astype("float32")]

    if return_max_len:
        return_list += [max_len]

    if return_num_token:
        num_token = 0
        for inst in insts:
            num_token += len(inst)
        return_list += [num_token]

    return return_list if len(return_list) > 1 else return_list[0]


if __name__ == "__main__":
    pass
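The core padding-plus-attention-bias pattern in `pad_batch_data` can be illustrated standalone. This is a minimal sketch with hypothetical toy token ids, not an import of the module: variable-length id lists are right-padded with `pad_idx`, and an additive bias matrix is built that is 0 on real tokens and -1e9 on padding, so padded positions receive effectively zero attention weight after softmax.

```python
import numpy as np

# Toy batch of token-id lists with different lengths (hypothetical values).
insts = [[1, 11, 12, 2], [1, 21, 2]]
max_len = max(len(x) for x in insts)

# Same scheme as pad_batch_data: pad with pad_idx, then build an additive
# attention bias, 0 on real tokens and -1e9 on padding positions.
pad_idx = 0
padded = np.array([x + [pad_idx] * (max_len - len(x)) for x in insts])
bias = np.array([[0.0] * len(x) + [-1e9] * (max_len - len(x)) for x in insts])
# Broadcast the per-key row to every query position: (batch, 1, max_len)
# tiled to (batch, max_len, max_len).
bias = np.tile(bias.reshape([-1, 1, max_len]), [1, max_len, 1])

print(padded.shape)  # (2, 4)
print(bias.shape)    # (2, 4, 4)
```

Adding this bias to the raw attention logits before softmax drives the padded columns to ~0 probability, which is why any vocabulary id is safe to use as padding here.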

ERNIE/finetune/__init__.py

Whitespace-only changes.
