Skip to content

Commit 37fe9eb

Browse files
committed
feat: torchJD working version
1 parent 9f447a3 commit 37fe9eb

File tree

3 files changed

+39
-4
lines changed

3 files changed

+39
-4
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
"epochs=50000",
1919
"eval_interval=5000",
2020
"lr=1e-4",
21-
"global_batch_size=768",
21+
"global_batch_size=1",
2222
"puzzle_emb_lr=1e-4",
2323
"weight_decay=1.0",
2424
"puzzle_emb_weight_decay=1.0",

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,4 +26,5 @@ dependencies = [
2626
"numba",
2727
"triton",
2828
"pre-commit",
29+
"torchjd>=0.8.0",
2930
]

src/recursion/pretrain.py

Lines changed: 37 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from adam_atan2_pytorch import AdamAtan2
1717
from omegaconf import DictConfig
1818
from torch import nn
19+
from torch.nn.functional import cosine_similarity
1920
from torch.utils.data import DataLoader
2021
from torchjd.aggregation import UPGrad
2122
from torchjd.autojac._transform import Accumulate, Aggregate, OrderedSet
@@ -26,6 +27,35 @@
2627
from recursion.puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
2728
from recursion.utils.functions import get_model_source_path, load_model_class
2829

# Current optimisation step; updated by train_batch so the torchjd forward
# hooks below can tag their wandb logs with the right step index.
global_step = 0


def print_gramian(_, inputs, __):
    """Forward hook for the UPGrad weighting: log statistics of the Gramian.

    Hook signature is ``(module, inputs, output)``; only ``inputs[0]`` — the
    Gramian matrix of the per-task gradients — is used.
    """
    # Under torch.distributed, log from rank 0 only to avoid duplicate entries.
    if not dist.is_initialized() or dist.get_rank() == 0:
        gramian = inputs[0]
        # Normalise each entry by the outer product of the gradient norms so
        # the scaled entries become cosine similarities in [-1, 1].
        diag = torch.diag(gramian).sqrt()
        outer = diag.unsqueeze(0) * diag.unsqueeze(1)
        # Guard against zero-norm gradients: a zero diagonal entry would make
        # the division produce inf/nan. Division is out-of-place, so no
        # defensive .clone() of the input is needed.
        scaled_gramian = gramian / outer.clamp_min(torch.finfo(gramian.dtype).tiny)

        # .item() converts 0-d tensors to plain Python scalars for wandb.
        wandb.log(
            {
                "gramian_min": gramian.min().item(),
                "gramian_mean": gramian.mean().item(),
                "gramian_median": gramian.median().item(),
                "min_gramian_scaled": scaled_gramian.min().item(),
            },
            step=global_step,
        )
50+
def log_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
    """Logs to wandb the cosine similarity between the aggregation and the average gradient.

    Forward hook for the UPGrad aggregator: ``inputs[0]`` is the matrix of
    per-task gradients (presumably one row per task — matches the
    ``mean(dim=0)`` below; confirm against the aggregator's hook contract)
    and ``aggregation`` is the combined gradient the aggregator produced.
    """
    # Under torch.distributed, log from rank 0 only to avoid duplicate entries.
    if not dist.is_initialized() or dist.get_rank() == 0:
        matrix = inputs[0]
        # Plain gradient-descent direction: the unweighted mean of the task gradients.
        gd_output = matrix.mean(dim=0)
        similarity = cosine_similarity(aggregation, gd_output, dim=0)
        wandb.log({"gd_similarity": similarity.item()}, step=global_step)
2959

3060
class LossConfig(pydantic.BaseModel):
3161
model_config = pydantic.ConfigDict(extra="allow")
@@ -320,7 +350,7 @@ def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetada
320350
return evaluators
321351

322352

323-
UPDATE_EVERY = 2
353+
UPDATE_EVERY = 8
324354

325355

326356
def train_batch(
@@ -331,7 +361,9 @@ def train_batch(
331361
rank: int,
332362
world_size: int,
333363
):
364+
global global_step
334365
train_state.step += 1
366+
global_step = train_state.step
335367
if train_state.step > train_state.total_steps: # At most train_total_steps
336368
return
337369

@@ -348,8 +380,8 @@ def train_batch(
348380
carry=train_state.carry, batch=batch, return_keys=[]
349381
)
350382

351-
current_step = train_state.carry.steps[0].item() # Something between 0 and 15
352-
if (current_step + 1) % UPDATE_EVERY:
383+
current_step = train_state.carry.steps[0].item() # Something between 1 and 16
384+
if current_step > 0 and (current_step % UPDATE_EVERY == 0):
353385
memories = train_state.carry.inner_carry.memories
354386
memories_wrt = train_state.carry.inner_carry.memories_wrt
355387
rec_model = train_state.model.model.inner.L_level
@@ -375,6 +407,8 @@ def train_batch(
375407
param: torch.stack(gradients, dim=0) for param, gradients in param_to_gradients.items()
376408
}
377409
aggregator = UPGrad()
410+
aggregator.weighting.weighting.weighting.register_forward_hook(print_gramian)
411+
aggregator.register_forward_hook(log_gd_similarity)
378412
transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(rec_model.parameters())))
379413
transform(param_to_jacobian) # This stores the aggregated Jacobian in the .grad fields
380414

0 commit comments

Comments
 (0)