forked from deepspeedai/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimple_model.py
More file actions
74 lines (59 loc) · 2.42 KB
/
simple_model.py
File metadata and controls
74 lines (59 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import json
import argparse
import torch
class SimpleModel(torch.nn.Module):
    """Tiny test model: one square linear layer feeding a cross-entropy loss.

    When ``empty_grad`` is true a second linear layer (``linear2``) is
    registered as well, but it only takes part in the forward pass on the
    instance whose ``rank`` is 0; for any other rank it is left out of the
    computation.
    """

    def __init__(self, hidden_dim, empty_grad=False, rank=0):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        # The extra layer exists only in the empty-grad configuration.
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.rank, self.empty_grad = rank, empty_grad

    def forward(self, x, y):
        """Return the cross-entropy loss of the projected ``x`` against ``y``."""
        logits = self.linear(x)
        # Only rank 0 of an empty-grad model routes the input through linear2.
        if self.empty_grad and self.rank == 0:
            logits = logits + self.linear2(x)
        return self.cross_entropy_loss(logits, y)
class SimpleOptimizer(torch.optim.Optimizer):
    """Minimal gradient-descent optimizer: ``p <- p - lr * p.grad``.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to update.
        lr: Step size applied to every gradient in a group.
    """

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable that is invoked once before the update;
                its return value is handed back to the caller.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                # Parameters that did not receive a gradient are left as-is.
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Use the keyword ``alpha`` form: the positional
                # ``add_(alpha, tensor)`` overload is deprecated in PyTorch.
                p.data.add_(d_p, alpha=-group['lr'])

        return loss
def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    """Build a DataLoader over random features and random integer labels.

    Features are drawn with ``torch.randn`` as a ``(total_samples, hidden_dim)``
    tensor of ``dtype``; labels are ``torch.long`` values filled in place by
    ``random_(hidden_dim)``. The batch size is whatever
    ``model.train_micro_batch_size_per_gpu()`` returns.
    """
    micro_batch = model.train_micro_batch_size_per_gpu()

    features = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    labels = torch.empty(total_samples, dtype=torch.long, device=device)
    labels.random_(hidden_dim)

    dataset = torch.utils.data.TensorDataset(features, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=micro_batch)
def create_config_from_dict(tmpdir, config_dict):
    """Write ``config_dict`` as JSON to ``temp_config.json`` inside ``tmpdir``.

    Returns the path of the file that was written.
    """
    target = os.path.join(tmpdir, 'temp_config.json')
    with open(target, 'w') as handle:
        handle.write(json.dumps(config_dict))
    return target
def args_from_dict(tmpdir, config_dict):
    """Return an argparse namespace pointing at a config file built from a dict.

    ``config_dict`` is written to disk via ``create_config_from_dict``; the
    returned namespace carries ``deepspeed=True``, ``deepspeed_config`` set to
    the written file's path, and ``local_rank=0``.
    """
    cfg_file = create_config_from_dict(tmpdir, config_dict)

    # Parse an empty argument list so the process's real argv is ignored.
    namespace = argparse.ArgumentParser().parse_args(args=[])
    namespace.deepspeed = True
    namespace.deepspeed_config = cfg_file
    namespace.local_rank = 0
    return namespace