forked from deepspeedai/DeepSpeedExamples
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_train_bert.py
More file actions
108 lines (100 loc) · 3.78 KB
/
test_train_bert.py
File metadata and controls
108 lines (100 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import tempfile
from typing import Iterator

import numpy as np
import pytest
import torch
import tqdm
from transformers import AutoTokenizer

from train_bert import create_data_iterator, create_model, load_model_checkpoint, train
@pytest.fixture(scope="function")
def checkpoint_dir() -> Iterator[str]:
    """Provide a fresh temporary directory for a single test.

    The fixture is a generator: it yields once and is resumed after the test
    finishes, so its return annotation is an iterator of ``str`` rather than
    a plain ``str``.

    Yields:
        The path of a directory created by ``tempfile.TemporaryDirectory``.
        The directory is removed when the ``with`` block exits, i.e. once the
        generator is resumed after the yield.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        yield tmpdirname
def test_masking_stats(tol: float = 1e-3):
    """Compare empirical masking rates against the configured probabilities.

    Pulls ``num_samples`` batches from ``create_data_iterator`` and tallies,
    over all attended tokens except two per sequence (BOS / EOS):

    * positions whose target is not the pad token (counted as masked),
    * masked positions whose source is neither the mask token nor equal to
      the target (counted as randomly replaced),
    * masked positions whose source equals the target (counted as unchanged).

    Each observed fraction must match its expected value within the absolute
    tolerance ``tol``.
    """
    settings = {
        "mask_prob": 0.15,
        "random_replace_prob": 0.1,
        "unmask_replace_prob": 0.1,
        "batch_size": 8,
    }
    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    dataloader = create_data_iterator(**settings)
    pad_id = tokenizer.pad_token_id
    mask_id = tokenizer.mask_token_id

    num_samples = 10000
    n_total = 0
    n_masked = 0
    n_random = 0
    n_unchanged = 0

    progress = tqdm.tqdm(enumerate(dataloader, start=1), total=num_samples)
    for step, batch in progress:
        attention = batch["attention_mask"]
        src = batch["src_tokens"]
        tgt = batch["tgt_tokens"]

        # BOS / EOS are never masked, so two tokens per sequence are left out
        # of the denominator.
        n_total += attention.sum().item() - (2 * attention.size(0))

        is_masked = tgt != pad_id
        n_masked += is_masked.sum().item()

        # Masked positions whose source token is something other than the
        # mask token: either a random replacement or the original token.
        not_mask_token = torch.logical_and(is_masked, src != mask_id)
        unchanged = torch.logical_and(not_mask_token, src == tgt)

        unchanged_count = unchanged.sum().item()
        n_unchanged += unchanged_count
        n_random += not_mask_token.sum().item() - unchanged_count

        if step == num_samples:
            break

    observed_mask = n_masked / n_total
    observed_random = n_random / n_total
    observed_unchanged = n_unchanged / n_total

    # The replacement probabilities are conditional on a token being masked,
    # hence the product with ``mask_prob``.
    expected_mask = settings["mask_prob"]
    expected_random = settings["random_replace_prob"] * settings["mask_prob"]
    expected_unchanged = settings["unmask_replace_prob"] * settings["mask_prob"]

    assert np.isclose(observed_mask, expected_mask, atol=tol)
    assert np.isclose(observed_random, expected_random, atol=tol)
    assert np.isclose(observed_unchanged, expected_unchanged, atol=tol)
def test_model_checkpointing(checkpoint_dir: str):
    """Train a tiny model, then verify checkpoint saving, loading and resuming.

    Steps:
        1. Run ``train`` for 5 iterations, checkpointing every 2, and expect
           3 ``*.pt`` files in the returned experiment directory.
        2. Load the checkpoint into a freshly created model / optimizer and
           expect the restored step to be 5 and the weights to match the
           contents of ``checkpoint.iter_5.pt``.
        3. Run ``train`` again up to 10 iterations, loading from that
           experiment directory.
    """
    architecture = {
        "num_layers": 2,
        "num_heads": 4,
        "ff_dim": 64,
        "h_dim": 64,
    }

    # First train a tiny model for 5 iterations
    train_params = {
        "checkpoint_dir": checkpoint_dir,
        "checkpoint_every": 2,
        **architecture,
        "num_iterations": 5,
    }
    exp_dir = train(**train_params)

    # now check that we have 3 checkpoints
    saved_checkpoints = list(exp_dir.glob("*.pt"))
    assert len(saved_checkpoints) == 3

    model = create_model(**architecture, dropout=0.1)
    optimizer = torch.optim.Adam(model.parameters())
    step, model, optimizer = load_model_checkpoint(exp_dir, model, optimizer)
    assert step == 5

    # the saved checkpoint would be for iteration 5
    reference = torch.load(exp_dir / "checkpoint.iter_5.pt")
    expected_weights = reference["model"]
    restored_weights = model.state_dict()

    assert set(restored_weights.keys()) == set(expected_weights.keys())
    for name, tensor in restored_weights.items():
        assert torch.allclose(tensor, expected_weights[name])

    # Finally, try training with the checkpoint
    resume_params = {
        **{k: v for k, v in train_params.items() if k != "checkpoint_dir"},
        "load_checkpoint_dir": str(exp_dir),
        "num_iterations": 10,
    }
    train(**resume_params)