Add batched JD training
ValerianRey committed Dec 18, 2025
commit 927b36eb891bbc16208d4609e94db368e64b7aba
8 changes: 4 additions & 4 deletions src/recursion/dataset/repeat_after_k.py
@@ -2,13 +2,13 @@
 from torch import Tensor
 
 
-def make_sequence(length: int, k: int) -> tuple[Tensor, Tensor]:
-    seq = torch.randint(low=0, high=2, size=[length + k])
-    input = seq[k:]
+def make_sequences(length: int, k: int, batch_size: int) -> tuple[Tensor, Tensor]:
+    seq = torch.randint(low=0, high=2, size=[batch_size, length + k])
+    input = seq[:, k:]
 
     if k == 0:
         target = seq
     else:
-        target = seq[:-k]
+        target = seq[:, :-k]
 
     return input, target
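Note (not part of the diff): a minimal usage sketch of the batched helper above, assuming it is imported from recursion.dataset.repeat_after_k. It only checks the output shapes and the repeat-after-k relation implied by the slicing.

import torch

from recursion.dataset.repeat_after_k import make_sequences

# Both outputs are [batch_size, length] tensors of 0/1 values.
inputs, targets = make_sequences(length=10, k=3, batch_size=4)
assert inputs.shape == (4, 10) and targets.shape == (4, 10)

# Each target bit is the input bit seen k steps earlier: targets[:, t] == inputs[:, t - k].
assert torch.equal(targets[:, 3:], inputs[:, :-3])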
89 changes: 69 additions & 20 deletions src/recursion/models/trivial_memory_model.py
@@ -2,11 +2,20 @@
 
 import torch
 from torch import Tensor, nn
+from torch.nn.functional import cosine_similarity
 from torch.optim import SGD
 from torchjd.aggregation import UPGrad
-from torchjd.autojac._transform import Accumulate, Aggregate, OrderedSet
+from torchjd.autojac._transform import (
+    Accumulate,
+    Aggregate,
+    Diagonalize,
+    Init,
+    Jac,
+    OrderedSet,
+    Select,
+)
 
-from recursion.dataset.repeat_after_k import make_sequence
+from recursion.dataset.repeat_after_k import make_sequences
 
 
 class TrivialMemoryModel(nn.Module):
@@ -26,51 +35,91 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]:
         return x
 
 
-input_sequence, target_sequence = make_sequence(50000, 3)
+batch_size = 16
+k = 3
+input_sequences, target_sequences = make_sequences(50000, k, batch_size=batch_size)
 
 memory_dim = 8
 
 model = TrivialMemoryModel(memory_dim)
 head = nn.Linear(memory_dim, 1)
-memory = torch.randn(memory_dim)
-criterion = nn.BCEWithLogitsLoss()
+memory = torch.zeros(batch_size, memory_dim)
+criterion = nn.BCEWithLogitsLoss(reduction="none")
 optimizer = SGD(model.parameters(), lr=1e-2)
 head_optimizer = SGD(head.parameters(), lr=1e-2)
 memories = []
 memories_wrt = []
-param_to_gradients = defaultdict(list)
+param_to_jacobians = defaultdict(list)
+torch.set_printoptions(linewidth=200)
 update_every = 4
 
 aggregator = UPGrad()
 
-for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)):
+
+def hook(_, args: tuple[Tensor], __) -> None:
+    jacobian = args[0]
+    gramian = jacobian @ jacobian.T
+    print(gramian[0, 0] / gramian[k * batch_size, k * batch_size])
+
+
+def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
+    """Prints the cosine similarity between the aggregation and the average gradient."""
+    matrix = inputs[0]
+    gd_output = matrix.mean(dim=0)
+    similarity = cosine_similarity(aggregation, gd_output, dim=0)
+    print(f"Cosine similarity: {similarity.item():.4f}")
+
+
+aggregator.register_forward_hook(hook)
+aggregator.register_forward_hook(print_gd_similarity)
+
+for i, (input, target) in enumerate(zip(input_sequences.T, target_sequences.T, strict=True)):
     memories_wrt.append(memory.detach().requires_grad_(True))
-    memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1])
 
+    memory = model(input.unsqueeze(1).to(dtype=torch.float32), memories_wrt[-1])
     output = head(memory)
-    loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32))
+    losses = criterion(output, target.unsqueeze(1).to(dtype=torch.float32))
+    loss = losses.mean()
     memories.append(memory)
     transform = Accumulate() << Aggregate(aggregator, OrderedSet(list(model.parameters())))
 
     print(f"{loss.item():.1e}")
 
     if (i + 1) % update_every == 0:
-        grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True)
+        # grad_output = torch.autograd.grad(loss, [memories[-1]], retain_graph=True)
+
+        ordered_set = OrderedSet(losses)
+        init = Init(ordered_set)
+        diag = Diagonalize(ordered_set)
+        jac = Jac(ordered_set, OrderedSet([memories[-1]]), chunk_size=None, retain_graph=True)
+        trans = jac << diag << init
+
+        trans.check_keys(set())
+
+        jac_output = trans({})
 
         for j in range(update_every):
-            grads = torch.autograd.grad(
-                memories[-j - 1],
-                list(model.parameters()) + [memories_wrt[-j - 1]],
-                grad_outputs=grad_output,
+            new_jac = Jac(
+                OrderedSet([memories[-j - 1]]),
+                OrderedSet(list(model.parameters()) + [memories_wrt[-j - 1]]),
+                chunk_size=None,
             )
-            grads_wrt_params = grads[:-1]
-            grad_output = grads[-1]
+            select_jac_wrt_model = Select(OrderedSet(list(model.parameters())))
+            select_jac_wrt_memory = Select(OrderedSet([memories_wrt[-j - 1]]))
+
+            jacobians = new_jac(jac_output)
+            jac_output = select_jac_wrt_memory(jacobians)
+
+            if j < update_every - 1:
+                jac_output = {memories[-j - 2]: jac_output[memories_wrt[-j - 1]]}
+
+            jac_wrt_params = select_jac_wrt_model(jacobians)
 
-            for param, grad in zip(model.parameters(), grads_wrt_params, strict=True):
-                param_to_gradients[param].append(grad)
+            for param, jacob in jac_wrt_params.items():
+                param_to_jacobians[param].append(jacob)
 
         param_to_jacobian = {
-            param: torch.stack(gradients, dim=0) for param, gradients in param_to_gradients.items()
+            param: torch.cat(jacobs, dim=0) for param, jacobs in param_to_jacobians.items()
         }
 
         optimizer.zero_grad()
@@ -79,7 +128,7 @@ def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]:
 
         memories = []
         memories_wrt = []
-        param_to_gradients = defaultdict(list)
+        param_to_jacobians = defaultdict(list)
 
     head_optimizer.zero_grad()
     torch.autograd.backward(loss, inputs=list(head.parameters()))
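Note (not part of the diff): a rough plain-autograd analogue of what the Init / Diagonalize / Jac / Select pipeline in the update block seems to compute, written only to make the review easier. The torchjd.autojac._transform classes are internal APIs, so the exact keys and shapes they produce are an assumption here. The idea: each of the last step's batch_size losses contributes one Jacobian row per unrolled step, and the concatenated [update_every * batch_size, *param.shape] matrices are what UPGrad then aggregates.

# Sketch only (inefficient reference, not the committed code). Assumes the same
# variables as the training loop above (model, losses, memories, memories_wrt,
# update_every) at the point where the `(i + 1) % update_every == 0` branch runs.
params = list(model.parameters())
rows_per_param = {p: [] for p in params}

# One gradient row per scalar loss, taken w.r.t. the most recent memory state.
grad_rows = [
    torch.autograd.grad(l, memories[-1], retain_graph=True)[0] for l in losses.flatten()
]

for j in range(update_every):
    next_rows = []
    for row in grad_rows:
        grads = torch.autograd.grad(
            memories[-j - 1],
            params + [memories_wrt[-j - 1]],
            grad_outputs=row,
            retain_graph=True,
        )
        for p, g in zip(params, grads[:-1], strict=True):
            rows_per_param[p].append(g)
        # The gradient w.r.t. the detached memory input becomes the incoming row
        # for the previous step (what the re-keying of jac_output does in the diff).
        next_rows.append(grads[-1])
    grad_rows = next_rows

# One [update_every * batch_size, *param.shape] Jacobian per parameter,
# analogous to the torch.cat over per-step Jacobians in the diff.
param_to_jacobian_ref = {p: torch.stack(rows, dim=0) for p, rows in rows_per_param.items()}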