forked from SamsungSAILMontreal/TinyRecursiveModels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsparse_embedding.py
More file actions
139 lines (113 loc) · 4.27 KB
/
sparse_embedding.py
File metadata and controls
139 lines (113 loc) · 4.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
from typing import Union
import torch
import torch.distributed as dist
from torch import nn
from torch.optim.optimizer import Optimizer, ParamsT
from recursion.models.common import trunc_normal_init_
class CastedSparseEmbedding(nn.Module):
    """Embedding table trained sparsely through a fixed-size per-batch staging buffer.

    The full table lives in the persistent ``weights`` buffer, which is created
    without ``requires_grad``. In training mode ``forward`` copies the rows
    selected by ``inputs`` into ``local_weights`` (created with
    ``requires_grad=True``) and records the selected ids in ``local_ids``. The
    returned tensor is derived from ``local_weights``, so gradients flow into
    that staging buffer rather than into ``weights``. A separate optimizer
    (e.g. ``CastedSparseEmbeddingSignSGD_Distributed`` in this file) is
    responsible for writing updates back into ``weights``. In both modes the
    output is cast to ``cast_to``.

    Args:
        num_embeddings: Number of rows in the full table.
        embedding_dim: Width of each row.
        batch_size: Number of rows staged per training step. Training-mode
            ``inputs`` must be copyable into a ``(batch_size,)`` buffer.
        init_std: Standard deviation handed to ``trunc_normal_init_``.
        cast_to: dtype of the embeddings returned by ``forward``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        batch_size: int,
        init_std: float,
        cast_to: torch.dtype,
    ):
        super().__init__()
        self.cast_to = cast_to

        # Full table, saved in state_dict. The original author describes the
        # init as truncated LeCun normal (performed by the project helper).
        table = torch.empty((num_embeddings, embedding_dim))
        self.weights = nn.Buffer(trunc_normal_init_(table, std=init_std), persistent=True)

        # Staging rows for the current batch. They carry the gradient and are
        # excluded from state_dict.
        staged_rows = torch.zeros(batch_size, embedding_dim, requires_grad=True)
        self.local_weights = nn.Buffer(staged_rows, persistent=False)

        # Table ids of the staged rows; also excluded from state_dict.
        staged_ids = torch.zeros(batch_size, dtype=torch.int32)
        self.local_ids = nn.Buffer(staged_ids, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``weights`` selected by ``inputs``, cast to ``cast_to``.

        Eval mode indexes ``weights`` directly, with no gradient path. Training
        mode first stages the rows and ids into ``local_weights``/``local_ids``,
        then returns ``local_weights`` cast to ``cast_to`` so that autograd
        tracks the staging buffer.
        """
        if self.training:
            # Overwrite the staging buffers without recording the copy in autograd.
            with torch.no_grad():
                self.local_weights.copy_(self.weights[inputs])
                self.local_ids.copy_(inputs)
            return self.local_weights.to(self.cast_to)

        return self.weights[inputs].to(self.cast_to)
class CastedSparseEmbeddingSignSGD_Distributed(Optimizer):
def __init__(
self,
params: ParamsT,
world_size: int,
lr: Union[float, torch.Tensor] = 1e-3,
weight_decay: float = 1e-2,
):
if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= weight_decay:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
defaults = dict(lr=lr, weight_decay=weight_decay, world_size=world_size)
super().__init__(params, defaults)
@torch.no_grad
def step(self, closure=None): # type: ignore
for group in self.param_groups:
# Find the sparse embedding weights
local_weights_grad = None
local_ids = None
weights = None
assert len(group["params"]) == 3
for p in group["params"]:
if p.requires_grad:
local_weights_grad = p.grad
elif p.ndim == 1:
local_ids = p
elif p.ndim == 2:
weights = p
else:
assert False
assert local_ids is not None
assert weights is not None
# Apply SignSGD
# Adam ≈ SignSGD if gradient is very sparse
if local_weights_grad is not None:
_sparse_emb_signsgd_dist(
local_weights_grad,
local_ids,
weights,
lr=group["lr"],
weight_decay=group["weight_decay"],
world_size=group["world_size"],
)
def _sparse_emb_signsgd_dist(
local_weights_grad: torch.Tensor,
local_ids: torch.Tensor,
weights: torch.Tensor,
lr: float,
weight_decay: float,
world_size: int,
) -> None:
N, D = local_weights_grad.shape
# All-gather
all_weights_grad = local_weights_grad
all_ids = local_ids
if world_size > 1:
all_weights_grad = torch.empty(
(world_size * N, D), dtype=local_weights_grad.dtype, device=local_weights_grad.device
)
all_ids = torch.empty(world_size * N, dtype=local_ids.dtype, device=local_ids.device)
dist.all_gather_into_tensor(all_weights_grad, local_weights_grad)
dist.all_gather_into_tensor(all_ids, local_ids)
# Unique
grad_ids, inv = all_ids.unique(return_inverse=True)
grad = torch.zeros(
(grad_ids.shape[0], D), dtype=all_weights_grad.dtype, device=all_weights_grad.device
)
grad.scatter_add_(0, inv.unsqueeze(-1).expand(-1, D), all_weights_grad)
# SignSGD with decoupled weight decay
p = weights[grad_ids]
p.mul_(1.0 - lr * weight_decay).add_(torch.sign(grad), alpha=-lr)
# Write updated slices back
weights[grad_ids] = p