-
Notifications
You must be signed in to change notification settings - Fork 0
feat: [WIP] add RNN training #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,4 +26,6 @@ dependencies = [ | |
| "numba", | ||
| "triton", | ||
| "pre-commit", | ||
| "torchjd", | ||
| "torchviz" | ||
| ] | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| def make_sequence(length: int, k: int) -> tuple[Tensor, Tensor]: | ||
| seq = torch.randint(low=0, high=2, size=[length + k]) | ||
| input = seq[k:] | ||
|
|
||
| if k == 0: | ||
| target = seq | ||
| else: | ||
| target = seq[:-k] | ||
|
|
||
| return input, target |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| from collections import defaultdict | ||
|
|
||
| import torch | ||
| from torch import Tensor, nn | ||
| from torch.optim import SGD | ||
|
|
||
| from recursion.dataset.repeat_after_k import make_sequence | ||
|
|
||
|
|
||
| class TrivialMemoryModel(nn.Module): | ||
| def __init__(self, memory_dim: int): | ||
| super().__init__() | ||
|
|
||
| hidden_size = 2 * (1 + memory_dim) | ||
| self.fc1 = nn.Linear(1 + memory_dim, hidden_size) | ||
| self.fc2 = nn.Linear(hidden_size, memory_dim) | ||
| # self.fc3 = nn.Linear(memory_dim, 1) | ||
| self.relu = nn.ReLU() | ||
|
|
||
| def forward(self, input: Tensor, memory: Tensor) -> tuple[Tensor, Tensor]: | ||
| x = torch.cat([input, memory], dim=-1) | ||
| x = self.relu(self.fc1(x)) | ||
| x = self.fc2(x) | ||
|
|
||
| return x | ||
|
|
||
|
|
||
| input_sequence, target_sequence = make_sequence(7, 3) | ||
|
|
||
| memory_dim = 8 | ||
|
|
||
| model = TrivialMemoryModel(memory_dim) | ||
| head = nn.Linear(memory_dim, 1) | ||
| memory = torch.randn(memory_dim) | ||
| criterion = nn.BCEWithLogitsLoss() | ||
| optimizer = SGD(model.parameters(), lr=1e-2) | ||
| memories = [] | ||
| memories_wrt = [] | ||
|
|
||
| param_to_gradients = defaultdict(list) | ||
| torch.set_printoptions(linewidth=200) | ||
| update_every = 6 | ||
|
|
||
| from torchjd.aggregation import UPGradWeighting | ||
|
|
||
| weighting = UPGradWeighting() | ||
|
|
||
| for i, (input, target) in enumerate(zip(input_sequence, target_sequence, strict=True)): | ||
| memories_wrt.append(memory.detach().requires_grad_(True)) | ||
| memory = model(input.unsqueeze(0).to(dtype=torch.float32), memories_wrt[-1]) | ||
| output = head(memory) | ||
| loss = criterion(output, target.unsqueeze(0).to(dtype=torch.float32)) | ||
| memories.append(memory) | ||
|
|
||
| print(f"{loss.item():.1e}") | ||
|
|
||
| if (i + 1) % update_every == 0: | ||
| optimizer.zero_grad() | ||
|
|
||
| grad_output = torch.autograd.grad(loss, [memories[-1]]) | ||
|
|
||
| for j in range(update_every): | ||
| print(j) | ||
| grads = torch.autograd.grad( | ||
| memories[-j - 1], | ||
| list(model.parameters()) + [memories_wrt[-j - 1]], | ||
| grad_outputs=grad_output, | ||
| ) | ||
| grads_wrt_params = grads[:-1] | ||
| grad_output = grads[-1] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it could be possible to clone the parameters of the memory model at each call, it should not require more memory. But then if we do backward we obtain a grad for each of the copies, we can stack them. Of course this also works and later on we can also make this quite efficient with hooks. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess in this code, there is no training at all? (no .grad=...)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think the current method is almost maximally efficient. But maybe it's not expressive enough (can't really select paths of length 1, 2, 4, 8, etc, without computing also 3, 5, 6, 7, ..., for now).
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could maybe do what you say with a detached view of the parameters (I think cloning duplicates memory + is differentiable so the gradients would flow back to the original params) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Selecting only paths is doable only with residual RNN. But note that if you select only path to level 2 memory, then you don't train interaction between level 1 and level 2, which is not typically what we want to do. |
||
|
|
||
| for param, grad in zip(model.parameters(), grads_wrt_params, strict=True): | ||
| param_to_gradients[param].append(grad) | ||
|
|
||
| param_to_jacobian_matrix = { | ||
| param: torch.stack([g.flatten() for g in gradients], dim=0) | ||
| for param, gradients in param_to_gradients.items() | ||
| } | ||
| jacobian_matrix = torch.cat([mat for mat in param_to_jacobian_matrix.values()], dim=1) | ||
|
|
||
| gramian = jacobian_matrix @ jacobian_matrix.T | ||
| weights = weighting(gramian) | ||
| # print(jacobian_matrix.shape) | ||
| print(gramian) | ||
| print(weights) | ||
|
|
||
| # graph = make_dot(loss, params=dict(model.named_parameters()), show_attrs=True, show_saved=True) | ||
| # graph.view() | ||
|
|
||
| # graph = make_dot(attached_memories[-1], params=dict(model.named_parameters()), show_attrs=True, | ||
| # show_saved=True) | ||
| # graph.view() | ||
|
|
||
| # loss.backward() | ||
|
|
||
| # print("fc1 weights: ", model.fc1.weight.grad) | ||
| # print("fc1 biases: ", model.fc1.bias.grad) | ||
| # | ||
| # print("fc2 weights: ", model.fc2.weight.grad) | ||
| # print("fc2 biases: ", model.fc2.bias.grad) | ||
|
|
||
| optimizer.step() | ||
| memory = memory.detach() | ||
Uh oh!
There was an error while loading. Please reload this page.