Skip to content

Commit df1cb67

Browse files
author
Yibing Liu
authored
Merge pull request PaddlePaddle#5 from PaddlePaddle/fix_train_args
Enable batching not in tokens in pretraining
2 parents 8a0753a + e65ba41 commit df1cb67

File tree

2 files changed

+33
-16
lines changed

2 files changed

+33
-16
lines changed

BERT/reader/pretraining.py

Lines changed: 13 additions & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -36,6 +36,7 @@ def __init__(self,
3636
data_dir,
3737
vocab_path,
3838
batch_size=4096,
39+
in_tokens=True,
3940
max_seq_len=512,
4041
shuffle_files=True,
4142
epoch=100,
@@ -46,6 +47,7 @@ def __init__(self,
4647
self.vocab = self.load_vocab(vocab_path)
4748
self.data_dir = data_dir
4849
self.batch_size = batch_size
50+
self.in_tokens = in_tokens
4951
self.shuffle_files = shuffle_files
5052
self.epoch = epoch
5153
self.current_epoch = 0
@@ -60,8 +62,9 @@ def __init__(self,
6062
self.mask_id = self.vocab["[MASK]"]
6163
self.is_test = is_test
6264
self.generate_neg_sample = generate_neg_sample
63-
assert self.batch_size > 100, "Current batch size means total token's number, \
64-
it should not be set to too small number."
65+
if self.in_tokens:
66+
assert self.batch_size >= self.max_seq_len, "The number of " \
67+
"tokens in batch should not be smaller than max seq length."
6568

6669
if self.is_test:
6770
self.epoch = 1
@@ -245,12 +248,16 @@ def reader():
245248
continue
246249
yield sample
247250

248-
def batch_reader(reader, batch_size):
251+
def batch_reader(reader, batch_size, in_tokens):
249252
batch, total_token_num, max_len = [], 0, 0
250253
for parsed_line in reader():
251254
token_ids, sent_ids, pos_ids, label = parsed_line
252255
max_len = max(max_len, len(token_ids))
253-
if (len(batch) + 1) * max_len <= batch_size:
256+
if in_tokens:
257+
to_append = (len(batch) + 1) * max_len <= batch_size
258+
else:
259+
to_append = len(batch) < batch_size
260+
if to_append:
254261
batch.append(parsed_line)
255262
total_token_num += len(token_ids)
256263
else:
@@ -261,8 +268,8 @@ def batch_reader(reader, batch_size):
261268
if len(batch) > 0:
262269
yield batch, total_token_num
263270

264-
for batch_data, total_token_num in batch_reader(reader,
265-
self.batch_size):
271+
for batch_data, total_token_num in batch_reader(
272+
reader, self.batch_size, self.in_tokens):
266273
yield prepare_batch_data(
267274
batch_data,
268275
total_token_num,

BERT/train.py

Lines changed: 20 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -61,14 +61,15 @@
6161

6262
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
6363
data_g.add_arg("data_dir", str, "./data/train/", "Path to training data.")
64-
data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to training data.")
65-
data_g.add_arg("test_set_dir", str, None, "Path to training data.")
64+
data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to validation data.")
65+
data_g.add_arg("test_set_dir", str, None, "Path to test data.")
6666
data_g.add_arg("vocab_path", str, "./config/vocab.txt", "Vocabulary path.")
67-
data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
68-
data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.")
69-
data_g.add_arg("in_tokens", bool, False,
70-
"If set, the batch size will be the maximum number of tokens in one batch. "
71-
"Otherwise, it will be the maximum number of examples in one batch.")
67+
data_g.add_arg("max_seq_len", int, 512, "Tokens' number of the longest seqence allowed.")
68+
data_g.add_arg("batch_size", int, 8192,
69+
"The total number of examples in one batch for training, see also --in_tokens.")
70+
data_g.add_arg("in_tokens", bool, True,
71+
"If set, the batch size will be the maximum number of tokens in one batch. "
72+
"Otherwise, it will be the maximum number of examples in one batch.")
7273

7374
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
7475
run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
@@ -128,6 +129,7 @@ def predict_wrapper(args,
128129
data_path,
129130
vocab_path=args.vocab_path,
130131
batch_size=args.batch_size,
132+
in_tokens=args.in_tokens,
131133
voc_size=bert_config['vocab_size'],
132134
shuffle_files=False,
133135
epoch=1,
@@ -250,9 +252,16 @@ def train(args):
250252
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
251253

252254
print("Device count %d" % dev_count)
253-
print("theoretical memory usage: ")
254-
print(fluid.contrib.memory_usage(
255-
program=train_program, batch_size=args.batch_size // args.max_seq_len))
255+
if args.verbose:
256+
if args.in_tokens:
257+
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
258+
program=train_program,
259+
batch_size=args.batch_size // args.max_seq_len)
260+
else:
261+
lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
262+
program=train_program, batch_size=args.batch_size)
263+
print("Theoretical memory usage in training: %.3f - %.3f %s" %
264+
(lower_mem, upper_mem, unit))
256265

257266
nccl2_num_trainers = 1
258267
nccl2_trainer_id = 0
@@ -293,6 +302,7 @@ def train(args):
293302
data_reader = DataReader(
294303
data_dir=args.data_dir,
295304
batch_size=args.batch_size,
305+
in_tokens=args.in_tokens,
296306
vocab_path=args.vocab_path,
297307
voc_size=bert_config['vocab_size'],
298308
epoch=args.epoch,

0 commit comments

Comments (0)