|
@@ -61,14 +61,15 @@
 
 data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
 data_g.add_arg("data_dir", str, "./data/train/", "Path to training data.")
-data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to training data.")
-data_g.add_arg("test_set_dir", str, None, "Path to training data.")
+data_g.add_arg("validation_set_dir", str, "./data/validation/", "Path to validation data.")
+data_g.add_arg("test_set_dir", str, None, "Path to test data.")
 data_g.add_arg("vocab_path", str, "./config/vocab.txt", "Vocabulary path.")
-data_g.add_arg("max_seq_len", int, 512, "Number of words of the longest seqence.")
-data_g.add_arg("batch_size", int, 16, "Total examples' number in batch for training. see also --in_tokens.")
-data_g.add_arg("in_tokens", bool, False,
-               "If set, the batch size will be the maximum number of tokens in one batch. "
-               "Otherwise, it will be the maximum number of examples in one batch.")
+data_g.add_arg("max_seq_len", int, 512, "Number of tokens in the longest sequence allowed.")
+data_g.add_arg("batch_size", int, 8192,
+               "The total number of examples in one batch for training, see also --in_tokens.")
+data_g.add_arg("in_tokens", bool, True,
+               "If set, the batch size will be the maximum number of tokens in one batch. "
+               "Otherwise, it will be the maximum number of examples in one batch.")
 
 run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
 run_type_g.add_arg("is_distributed", bool, False, "If set, then start distributed training.")
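With in_tokens now defaulting to True, batch_size becomes a token budget rather than an example count. A minimal sketch of that batching rule, assuming the reader closes a batch once the next example would overflow the budget (names are illustrative, not the repo's actual DataReader internals):

def batch_by_budget(examples, batch_size, in_tokens):
    """Group token-id lists into batches.

    in_tokens=True : batch_size is a token budget; a batch is emitted once
                     adding the next example would exceed it.
    in_tokens=False: batch_size counts examples, as before this commit.
    """
    batch, used = [], 0
    for example in examples:
        cost = len(example) if in_tokens else 1
        if batch and used + cost > batch_size:
            yield batch
            batch, used = [], 0
        batch.append(example)
        used += cost
    if batch:
        yield batch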
@@ -128,6 +129,7 @@ def predict_wrapper(args,
         data_path,
         vocab_path=args.vocab_path,
         batch_size=args.batch_size,
+        in_tokens=args.in_tokens,
         voc_size=bert_config['vocab_size'],
         shuffle_files=False,
         epoch=1,
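Threading in_tokens through to the prediction reader keeps validation and test batching consistent with training. Continuing the illustrative sketch above, the same data splits very differently under the two modes:

examples = [[0] * n for n in (500, 300, 200, 512)]  # dummy sequences of these lengths

# Example-count mode: batch_size=2 examples per batch.
print([len(b) for b in batch_by_budget(examples, batch_size=2, in_tokens=False)])   # [2, 2]

# Token-budget mode: batch_size=700 tokens per batch.
print([len(b) for b in batch_by_budget(examples, batch_size=700, in_tokens=True)])  # [1, 2, 1]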
@@ -250,9 +252,16 @@ def train(args):
         dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
 
     print("Device count %d" % dev_count)
-    print("theoretical memory usage: ")
-    print(fluid.contrib.memory_usage(
-        program=train_program, batch_size=args.batch_size // args.max_seq_len))
+    if args.verbose:
+        if args.in_tokens:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program,
+                batch_size=args.batch_size // args.max_seq_len)
+        else:
+            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
+                program=train_program, batch_size=args.batch_size)
+        print("Theoretical memory usage in training: %.3f - %.3f %s" %
+              (lower_mem, upper_mem, unit))
 
     nccl2_num_trainers = 1
     nccl2_trainer_id = 0
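fluid.contrib.memory_usage estimates memory from an example-level batch size, so in token mode the commit divides the token budget by the maximum sequence length, i.e. the worst-case number of examples per batch. With the new defaults this works out to the old example-count default:

batch_size, max_seq_len = 8192, 512   # new defaults introduced in this commit
worst_case_examples = batch_size // max_seq_len
print(worst_case_examples)            # 16, matching the previous batch_size default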
@@ -293,6 +302,7 @@ def train(args):
     data_reader = DataReader(
         data_dir=args.data_dir,
         batch_size=args.batch_size,
+        in_tokens=args.in_tokens,
         vocab_path=args.vocab_path,
         voc_size=bert_config['vocab_size'],
         epoch=args.epoch,
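One practical consequence of flipping the in_tokens default: restoring the old behavior now requires passing --in_tokens False on the command line, which only round-trips correctly if the argument helper parses booleans explicitly, since argparse's bool("False") is truthy. A sketch of the str2bool pattern an ArgumentGroup-style wrapper typically uses; this is an assumption about the helper, not the repo's verified code:

import argparse

def str2bool(v):
    # argparse's built-in bool() treats any non-empty string as True,
    # so "--in_tokens False" needs an explicit parser.
    return v.lower() in ("true", "t", "1")

class ArgumentGroup(object):
    def __init__(self, parser, title, desc):
        self._group = parser.add_argument_group(title=title, description=desc)

    def add_arg(self, name, dtype, default, help_text):
        dtype = str2bool if dtype == bool else dtype
        self._group.add_argument(
            "--" + name, default=default, type=dtype,
            help=help_text + " Default: %(default)s.")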
|