@@ -16,6 +16,7 @@
 from adam_atan2_pytorch import AdamAtan2
 from omegaconf import DictConfig
 from torch import nn
+from torch.nn.functional import cosine_similarity
 from torch.utils.data import DataLoader
 from torchjd.aggregation import UPGrad
 from torchjd.autojac._transform import Accumulate, Aggregate, OrderedSet
@@ -26,6 +27,35 @@
 from recursion.puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata
 from recursion.utils.functions import get_model_source_path, load_model_class
 
+global_step = 0
+
+
+def print_gramian(_, inputs, __):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        # print(inputs[0])
+        diag = torch.diag(inputs[0]).sqrt()
+        outer = diag.unsqueeze(0) * diag.unsqueeze(1)
+        scaled_gramian = inputs[0].clone() / outer
+
+        wandb.log(
+            {
+                "gramian_min": inputs[0].min(),
+                "gramian_mean": inputs[0].mean(),
+                "gramian_median": inputs[0].median(),
+                "min_gramian_scaled": scaled_gramian.min(),
+            },
+            step=global_step,
+        )
+
+
+def log_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
+    """Logs the cosine similarity between the aggregation and the average gradient."""
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        matrix = inputs[0]
+        gd_output = matrix.mean(dim=0)
+        similarity = cosine_similarity(aggregation, gd_output, dim=0)
+        wandb.log({"gd_similarity": similarity.item()}, step=global_step)
+
 
 class LossConfig(pydantic.BaseModel):
     model_config = pydantic.ConfigDict(extra="allow")
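
The scaling inside print_gramian is the standard cosine-normalization of a Gramian: dividing each entry G[i, j] by sqrt(G[i, i] * G[j, j]) turns inner products into pairwise cosine similarities of the underlying gradient rows, so "min_gramian_scaled" tracks the most conflicting pair of per-step gradients. A minimal standalone sketch (not part of the commit; the 4 x 10 matrix is a made-up stand-in for the stacked Jacobian):

import torch
from torch.nn.functional import cosine_similarity

J = torch.randn(4, 10)                  # hypothetical stack of 4 per-step gradients
G = J @ J.T                             # Gramian: G[i, j] = <J[i], J[j]>
d = torch.diag(G).sqrt()                # per-row gradient norms
scaled = G / (d.unsqueeze(0) * d.unsqueeze(1))
ref = cosine_similarity(J.unsqueeze(1), J.unsqueeze(0), dim=-1)
assert torch.allclose(scaled, ref, atol=1e-6)  # scaled Gramian is the cosine-similarity matrix
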
@@ -320,7 +350,7 @@ def create_evaluators(config: PretrainConfig, eval_metadata: PuzzleDatasetMetada
     return evaluators
 
 
-UPDATE_EVERY = 2
+UPDATE_EVERY = 8
 
 
 def train_batch(
@@ -331,7 +361,9 @@ def train_batch(
     rank: int,
     world_size: int,
 ):
+    global global_step
     train_state.step += 1
+    global_step = train_state.step
     if train_state.step > train_state.total_steps:  # At most train_total_steps
         return
 
@@ -348,8 +380,8 @@ def train_batch(
         carry=train_state.carry, batch=batch, return_keys=[]
     )
 
-    current_step = train_state.carry.steps[0].item()  # Something between 0 and 15
-    if (current_step + 1) % UPDATE_EVERY:
+    current_step = train_state.carry.steps[0].item()  # Something between 1 and 16
+    if current_step > 0 and (current_step % UPDATE_EVERY == 0):
         memories = train_state.carry.inner_carry.memories
         memories_wrt = train_state.carry.inner_carry.memories_wrt
         rec_model = train_state.model.model.inner.L_level
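
A note on the condition change above: the old test, (current_step + 1) % UPDATE_EVERY, is truthy whenever the remainder is nonzero, so with the previous UPDATE_EVERY = 2 it fired on every even step rather than once per period; the new test fires exactly on multiples of UPDATE_EVERY. A quick standalone check (not part of the commit), assuming the 1-to-16 step range from the updated comment:

UPDATE_EVERY = 8
steps = range(1, 17)
old_trigger = [s for s in steps if (s + 1) % 2]  # old condition with UPDATE_EVERY = 2
new_trigger = [s for s in steps if s > 0 and s % UPDATE_EVERY == 0]
print(old_trigger)  # [2, 4, 6, 8, 10, 12, 14, 16]
print(new_trigger)  # [8, 16]
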
@@ -375,6 +407,8 @@ def train_batch(
             param: torch.stack(gradients, dim=0) for param, gradients in param_to_gradients.items()
         }
         aggregator = UPGrad()
+        aggregator.weighting.weighting.weighting.register_forward_hook(print_gramian)
+        aggregator.register_forward_hook(log_gd_similarity)
         transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(rec_model.parameters())))
         transform(param_to_jacobian)  # This stores the aggregated Jacobian in the .grad fields