losses.py (forked from SamsungSAILMontreal/TinyRecursiveModels)
from typing import Any, Dict, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
from torch import nn

IGNORE_LABEL_ID = -100
def s(x, epsilon=1e-30):
    # Strictly positive surrogate for exp: x + 1 for x >= 0, 1 / (1 - x)
    # for x < 0; epsilon guards the division.
    return torch.where(x < 0, 1 / (1 - x + epsilon), x + 1)


def log_stablemax(x, dim=-1):
    # Log-probabilities of the "stablemax" distribution: normalize s(x)
    # along `dim` instead of exponentiating.
    s_x = s(x)
    return torch.log(s_x / torch.sum(s_x, dim=dim, keepdim=True))
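# Note (illustrative, not from the original file): s(x) grows linearly for
# x >= 0 and decays as 1 / (1 - x) for x < 0, so normalizing s(x) instead of
# exp(x) gives a softmax-like distribution with polynomial rather than
# exponential tails. Quick sanity check:
#
#     probs = log_stablemax(torch.tensor([0.0, 10.0, -10.0])).exp()
#     assert torch.allclose(probs.sum(), torch.tensor(1.0))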
def stablemax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # Compute in float64 for extra numerical headroom.
    logprobs = log_stablemax(logits.to(torch.float64), dim=-1)

    if valid_mask is None:
        valid_mask = labels != ignore_index
    # Map ignored positions to class 0 so the gather index stays in range.
    transformed_labels = torch.where(valid_mask, labels, 0)
    prediction_logprobs = torch.gather(
        logprobs, index=transformed_labels.to(torch.long).unsqueeze(-1), dim=-1
    ).squeeze(-1)

    # Per-token negative log-likelihood; ignored positions contribute 0.
    return -torch.where(valid_mask, prediction_logprobs, 0)
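# Usage sketch (illustrative): per-token losses for a padded batch, where
# padding carries IGNORE_LABEL_ID and contributes exactly zero loss.
#
#     logits = torch.randn(2, 5, 11)              # [batch, seq, vocab]
#     labels = torch.randint(0, 11, (2, 5))
#     labels[:, -1] = IGNORE_LABEL_ID
#     per_token = stablemax_cross_entropy(logits, labels,
#                                         ignore_index=IGNORE_LABEL_ID)
#     # per_token: [2, 5], float64, zeros at the ignored positions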
def softmax_cross_entropy(logits, labels, ignore_index: int = -100, valid_mask=None):
    # `valid_mask` is accepted for interface parity with
    # stablemax_cross_entropy (the call site below passes it); masking is
    # handled here by `ignore_index` instead.
    # Cast logits to f32 and flatten to [B * SeqLen, Vocab] for F.cross_entropy,
    # then restore the per-token [B, SeqLen] shape.
    return F.cross_entropy(
        logits.to(torch.float32).view(-1, logits.shape[-1]),
        labels.to(torch.long).view(-1),
        ignore_index=ignore_index,
        reduction="none",
    ).view(labels.shape)
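# Note: both loss functions share the signature (logits, labels, ignore_index,
# valid_mask) and return unreduced per-token losses shaped like `labels`, so
# ACTLossHead below can select either one by name.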
class ACTLossHead(nn.Module):
    def __init__(self, model: nn.Module, loss_type: str):
        super().__init__()
        self.model = model
        # Resolve the loss function by name from this module's globals,
        # e.g. "stablemax_cross_entropy" or "softmax_cross_entropy".
        self.loss_fn = globals()[loss_type]

    def initial_carry(self, *args, **kwargs):
        return self.model.initial_carry(*args, **kwargs)  # type: ignore
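    # Construction sketch (illustrative; `my_model` and the initial_carry
    # arguments are placeholders for whatever the training script provides):
    #
    #     loss_head = ACTLossHead(my_model, loss_type="stablemax_cross_entropy")
    #     carry = loss_head.initial_carry(batch)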
    def forward(
        self,
        return_keys: Sequence[str],
        # Model args
        **model_kwargs,
    ) -> Tuple[
        Any, torch.Tensor, Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]], torch.Tensor
    ]:
        # Model logits: B x SeqLen x D
        new_carry, outputs = self.model(**model_kwargs)
        labels = new_carry.current_data["labels"]

        with torch.no_grad():
            # Preds
            outputs["preds"] = torch.argmax(outputs["logits"], dim=-1)

            # Correctness
            mask = labels != IGNORE_LABEL_ID
            loss_counts = mask.sum(-1)
            loss_divisor = loss_counts.clamp_min(1).unsqueeze(-1)  # Avoid NaNs in division

            is_correct = mask & (torch.argmax(outputs["logits"], dim=-1) == labels)
            seq_is_correct = is_correct.sum(-1) == loss_counts

            # Metrics, accumulated only over sequences that halted this step
            valid_metrics = new_carry.halted & (loss_counts > 0)
            metrics = {
                "count": valid_metrics.sum(),
                # Mean per-token accuracy of halted sequences
                "accuracy": torch.where(
                    valid_metrics, (is_correct.to(torch.float32) / loss_divisor).sum(-1), 0
                ).sum(),
                # Sequences whose every supervised token is correct
                "exact_accuracy": (valid_metrics & seq_is_correct).sum(),
                # Agreement between the halt head's decision (logit >= 0) and
                # actual sequence correctness
                "q_halt_accuracy": (
                    valid_metrics & ((outputs["q_halt_logits"] >= 0) == seq_is_correct)
                ).sum(),
                "steps": torch.where(valid_metrics, new_carry.steps, 0).sum(),
            }
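            # Note (assumption, not in the original file): these metrics are
            # sums, not means; a training loop would typically divide by
            # "count" when logging.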
        # Losses (outside no_grad: lm_loss and q_halt_loss carry gradients)
        lm_loss = (
            self.loss_fn(outputs["logits"], labels, ignore_index=IGNORE_LABEL_ID, valid_mask=mask)
            / loss_divisor
        ).sum()
        # The halt head is trained to predict whether the sequence is already
        # fully correct (a binary target per sequence).
        q_halt_loss = F.binary_cross_entropy_with_logits(
            outputs["q_halt_logits"],
            seq_is_correct.to(outputs["q_halt_logits"].dtype),
            reduction="sum",
        )

        metrics.update(
            {
                "lm_loss": lm_loss.detach(),
                "q_halt_loss": q_halt_loss.detach(),
            }
        )

        # Q continue (bootstrapping target loss); Alexia: this fits Q-learning,
        # but seems totally unnecessary
        q_continue_loss = 0
        if "target_q_continue" in outputs:
            q_continue_loss = F.binary_cross_entropy_with_logits(
                outputs["q_continue_logits"], outputs["target_q_continue"], reduction="sum"
            )
            metrics["q_continue_loss"] = q_continue_loss.detach()

        # Filter outputs for return
        detached_outputs = {k: outputs[k].detach() for k in return_keys if k in outputs}

        return (
            new_carry,
            lm_loss + 0.5 * (q_halt_loss + q_continue_loss),
            metrics,
            detached_outputs,
            new_carry.halted.all(),
        )
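# ---------------------------------------------------------------------------
# Smoke-test sketch (illustrative, not part of the original file). _Carry and
# _DummyModel are hypothetical stand-ins for the recursive model and carry
# object the training script normally provides; they only mimic the pieces
# ACTLossHead touches: initial_carry, current_data, halted, steps, and the
# "logits" / "q_halt_logits" outputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _Carry:
        current_data: Dict[str, torch.Tensor]
        halted: torch.Tensor
        steps: torch.Tensor

    class _DummyModel(nn.Module):
        def __init__(self, vocab_size: int = 11, hidden: int = 16):
            super().__init__()
            self.hidden = hidden
            self.head = nn.Linear(hidden, vocab_size)
            self.q_head = nn.Linear(hidden, 1)

        def initial_carry(self, batch):
            b = batch["labels"].shape[0]
            return _Carry(
                current_data=batch,
                halted=torch.zeros(b, dtype=torch.bool),
                steps=torch.zeros(b, dtype=torch.long),
            )

        def forward(self, carry, batch):
            # Random "hidden states" stand in for the recursive computation.
            h = torch.randn(*batch["labels"].shape, self.hidden)
            outputs = {
                "logits": self.head(h),                              # [B, S, V]
                "q_halt_logits": self.q_head(h).mean(dim=(-2, -1)),  # [B]
            }
            new_carry = _Carry(batch, torch.ones_like(carry.halted), carry.steps + 1)
            return new_carry, outputs

    batch = {"labels": torch.randint(0, 11, (2, 5))}
    loss_head = ACTLossHead(_DummyModel(), loss_type="stablemax_cross_entropy")
    carry = loss_head.initial_carry(batch)
    carry, loss, metrics, outs, all_halted = loss_head(
        return_keys=["preds"], carry=carry, batch=batch
    )
    loss.backward()  # lm_loss and q_halt_loss are differentiable
    print(float(loss), {k: float(v) for k, v in metrics.items()}, bool(all_halted))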