forked from SamsungSAILMontreal/TinyRecursiveModels
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharc.py
More file actions
200 lines (159 loc) · 6.54 KB
/
arc.py
File metadata and controls
200 lines (159 loc) · 6.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import json
import os
from typing import Dict, Optional, Sequence
import numpy as np
import torch
import torch.distributed as dist
from numba import njit
from recursion.dataset.build_arc_dataset import arc_grid_to_np, grid_hash, inverse_aug
from recursion.dataset.common import PuzzleDatasetMetadata
@njit
def _crop(grid: np.ndarray):
    """Extract the top-left colour region of a flattened 30x30 token grid.

    The input is reshaped to 30x30 and scanned row by row starting at the
    top-left corner. A token counts as valid when it lies in 2..11; the first
    token outside that range (the original docstring refers to these as EOS
    tokens) ends the usable width of the current row, and the usable width can
    only shrink from one row to the next. Of the rectangles anchored at (0, 0)
    produced this way, the one with the largest area is kept (the earliest one
    wins ties).

    Returns the selected region with 2 subtracted from every token, cast to
    ``np.uint8``. The result is an empty array exactly when the top-left token
    is itself invalid.
    """
    board = grid.reshape(30, 30)
    n_rows, n_cols = board.shape

    best_area = 0
    best_rows = 0
    best_cols = 0
    width = n_cols

    for row in range(n_rows):
        # Shrink the admissible width to the first invalid token of this row;
        # setting width = col makes the loop condition fail immediately.
        col = 0
        while col < width:
            token = board[row, col]
            if token < 2 or token > 11:
                width = col
            else:
                col += 1

        area = (row + 1) * width
        if area > best_area:
            best_area = area
            best_rows = row + 1
            best_cols = width

    return (board[:best_rows, :best_cols] - 2).astype(np.uint8)
class ARC:
    """Evaluator that votes over ARC predictions and reports pass@K metrics.

    ``update_batch`` crops each predicted grid and passes it through the
    inverse transform returned by ``inverse_aug``. Predictions are grouped
    under the recovered puzzle name and the hash of the (inverse-transformed)
    input grid; presumably this undoes the dataset augmentation, so verify this
    against ``build_arc_dataset``.

    ``result`` gathers the grouped predictions onto rank 0. Candidates are
    ranked by vote count, with ties broken by the mean of
    sigmoid(``q_halt_logits``). The method then scores pass@K and optionally
    writes ``submission.json``.
    """

    # Keys copied out of the batch / model outputs in ``update_batch``.
    required_outputs = {"inputs", "puzzle_identifiers", "q_halt_logits", "preds"}

    def __init__(
        self,
        data_path: str,
        eval_metadata: PuzzleDatasetMetadata,
        submission_K: int = 2,
        pass_Ks: Sequence[int] = (1, 2, 5, 10, 100, 1000),
        aggregated_voting: bool = True,
    ):
        """Load puzzle metadata and initialise empty per-rank state.

        Args:
            data_path: Directory containing ``identifiers.json`` (indexed by
                puzzle identifier) and ``test_puzzles.json`` (puzzle name ->
                ``{"test": [{"input": grid, "output": grid}, ...]}``).
            eval_metadata: Only ``blank_identifier_id`` is read; rows carrying
                that identifier are treated as padding.
            submission_K: Number of attempts written per test pair.
            pass_Ks: The k values for which pass@k is reported.
            aggregated_voting: If True, ``begin_eval`` keeps predictions from
                earlier passes so that votes accumulate across them.
        """
        super().__init__()
        self.pass_Ks = pass_Ks
        self.submission_K = submission_K
        self.aggregated_voting = aggregated_voting
        self.blank_identifier_id = eval_metadata.blank_identifier_id

        # Load identifiers and test puzzles
        with open(os.path.join(data_path, "identifiers.json"), "r", encoding="utf-8") as f:
            self.identifier_map = json.load(f)
        with open(os.path.join(data_path, "test_puzzles.json"), "r", encoding="utf-8") as f:
            self.test_puzzles = json.load(f)

        # Per-rank state:
        #   _local_hmap:  prediction hash -> predicted grid (np.ndarray)
        #   _local_preds: original puzzle name -> input hash -> [(prediction hash, halt prob), ...]
        self._local_hmap = {}
        self._local_preds = {}

    def begin_eval(self):
        """Start an evaluation pass; reset state unless votes are aggregated."""
        if not self.aggregated_voting:
            self._local_hmap = {}
            self._local_preds = {}

    def update_batch(self, batch: Dict[str, torch.Tensor], preds: Dict[str, torch.Tensor]):
        """Record one batch of predictions into the per-rank state.

        Args:
            batch: Input batch. Any of ``required_outputs`` may come from here.
            preds: Model outputs. They are processed after ``batch``, so a key
                present in both dicts takes its value from ``preds``.

        Raises:
            AssertionError: If neither dict provides ``q_halt_logits``, or if a
                cropped prediction contains values outside 0-9.
        """
        # Collect required outputs to CPU
        outputs = {}
        q_values = None
        for collection in (batch, preds):
            for k, v in collection.items():
                if k not in self.required_outputs:
                    continue
                if k == "q_halt_logits":
                    # Halting score in float64; its per-candidate mean is the vote tie-breaker.
                    q_values = v.to(torch.float64).sigmoid().cpu()
                else:
                    outputs[k] = v.cpu()

        assert q_values is not None

        # Remove padding rows. q_values must be filtered by the same mask;
        # otherwise zip() below would pair halting scores with the wrong
        # predictions whenever padding rows are not all at the end of the batch.
        mask = outputs["puzzle_identifiers"] != self.blank_identifier_id
        outputs = {k: v[mask] for k, v in outputs.items()}
        q_values = q_values[mask]

        for identifier, inp, pred, q in zip(
            outputs["puzzle_identifiers"].numpy(),
            outputs["inputs"].numpy(),
            outputs["preds"].numpy(),
            q_values.numpy(),
        ):
            name = self.identifier_map[identifier]
            orig_name, inverse_fn = inverse_aug(name)

            input_hash = grid_hash(inverse_fn(_crop(inp)))
            pred = inverse_fn(_crop(pred))
            assert np.all(
                (pred >= 0) & (pred <= 9)
            ), f"Puzzle {name}'s prediction out of 0-9 range."  # Sanity check

            # Store into local state
            pred_hash = grid_hash(pred)
            self._local_hmap[pred_hash] = pred
            self._local_preds.setdefault(orig_name, {}).setdefault(input_hash, []).append(
                (pred_hash, float(q))
            )

    def _gather_states(
        self,
        rank: int,
        world_size: int,
        group: Optional[torch.distributed.ProcessGroup],
    ):
        """Collect every rank's ``(hmap, preds)`` state onto rank 0.

        Returns a list with one entry per rank on rank 0, and None on other
        ranks. In a single-process run without an initialised default process
        group, ``dist.gather_object`` cannot be used, so the local state is
        returned directly instead.
        """
        local_state = (self._local_hmap, self._local_preds)
        if world_size == 1 and not (dist.is_available() and dist.is_initialized()):
            return [local_state]

        gathered = [None for _ in range(world_size)] if rank == 0 else None
        dist.gather_object(local_state, gathered, dst=0, group=group)
        return gathered

    def result(
        self,
        save_path: Optional[str],
        rank: int,
        world_size: int,
        group: Optional[torch.distributed.ProcessGroup] = None,
    ) -> Optional[Dict[str, float]]:
        """Vote on the gathered predictions and compute pass@K metrics.

        Args:
            save_path: Directory in which to write ``submission.json``, or None
                to skip writing it.
            rank: This process's rank. Only rank 0 votes and returns metrics.
            world_size: Number of participating processes.
            group: Optional process group passed to ``dist.gather_object``.

        Returns:
            On rank 0, ``{"ARC/pass@{k}": score}``. The score is the per-puzzle
            fraction of test pairs whose label is among the top-k voted
            candidates, averaged over all test puzzles. Returns None on other
            ranks.
        """
        global_hmap_preds = self._gather_states(rank, world_size, group)

        # Rank 0 logic
        if rank != 0:
            return None

        submission = {}
        correct = [0.0 for _ in self.pass_Ks]

        for name, puzzle in self.test_puzzles.items():
            submission[name] = []
            num_test_correct = [0 for _ in self.pass_Ks]

            for pair in puzzle["test"]:
                input_hash = grid_hash(arc_grid_to_np(pair["input"]))
                label_hash = grid_hash(arc_grid_to_np(pair["output"]))

                # Candidate hash -> [vote count, summed (later mean) halt prob]
                p_map = {}
                for _hmap, preds in global_hmap_preds:  # type: ignore
                    for h, q in preds.get(name, {}).get(input_hash, []):
                        stats = p_map.setdefault(h, [0, 0])
                        stats[0] += 1
                        stats[1] += q

                if not p_map:
                    print(f"Puzzle {name} has no predictions.")
                    # Still emit one entry for this test pair. Without it, the
                    # attempts for every later test pair of this puzzle would
                    # shift to the wrong position in the submission list.
                    submission[name].append(
                        {f"attempt_{i + 1}": [[0]] for i in range(self.submission_K)}
                    )
                    continue

                for stats in p_map.values():
                    stats[1] /= stats[0]
                # Lexicographic on [count, mean q]: most votes first, then highest mean q.
                ranked = sorted(p_map.items(), key=lambda kv: kv[1], reverse=True)

                # vote for different Ks
                for i, k in enumerate(self.pass_Ks):
                    num_test_correct[i] += any(h == label_hash for h, _ in ranked[:k])

                # Look up the grids for the top submission_K candidates.
                pred_grids = []
                for h, _ in ranked[: self.submission_K]:
                    for hmap, _preds in global_hmap_preds:  # type: ignore
                        if h in hmap:
                            pred_grids.append(hmap[h])
                            break

                # Pad to K by repeating the best candidate.
                while len(pred_grids) < self.submission_K:
                    pred_grids.append(pred_grids[0])

                submission[name].append(
                    {f"attempt_{i + 1}": grid.tolist() for i, grid in enumerate(pred_grids)}
                )

            # Per-puzzle correctness is the fraction of its test pairs solved.
            num_pairs = len(puzzle["test"])
            if num_pairs:
                for i in range(len(self.pass_Ks)):
                    correct[i] += num_test_correct[i] / num_pairs

        # Save submission
        if save_path is not None:
            with open(os.path.join(save_path, "submission.json"), "w") as f:
                json.dump(submission, f)

        # Guard against an empty puzzle set so the division cannot fail.
        num_puzzles = max(len(self.test_puzzles), 1)
        return {f"ARC/pass@{k}": correct[i] / num_puzzles for i, k in enumerate(self.pass_Ks)}