Commit 8120f2b ("upload", initial commit, 0 parents)

39 files changed: 27,428 additions & 0 deletions

LICENSE

Lines changed: 21 additions & 0 deletions
MIT License

Copyright (c) 2025. Samsung Electronics Co., Ltd. All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 147 additions & 0 deletions
# Less is More: Recursive Reasoning with Tiny Networks

This is the codebase for the paper "Less is More: Recursive Reasoning with Tiny Networks", where we present a recursive reasoning approach that reaches 45% on ARC-AGI-1 and 8% on ARC-AGI-2 with a tiny 7M-parameter neural network.

[Paper](https://arxiv.org/abs/2510.04871)

### How TRM works

The Tiny Recursion Model (TRM) recursively improves its predicted answer y with a tiny network. It starts from the embedded input question x, an initial embedded answer y, and a latent z. For up to K improvement steps, it tries to improve its answer y. It does so by i) recursively updating its latent z n times given the question x, the current answer y, and the current latent z (recursive reasoning), and then ii) updating its answer y given the current answer y and the current latent z. This recursive process lets the model progressively refine its answer (potentially correcting errors in its previous answer) in an extremely parameter-efficient manner while minimizing overfitting.

<p align="center">
<img src="assets/TRM_fig.png" alt="TRM-Figure" style="width:50%">
</p>
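In code, this loop is short. Below is a minimal illustrative sketch, not the repo's implementation: `trm_refine` and the summed-embedding combination of (x, y, z) are assumptions of this sketch, and the defaults K=16 and n=6 mirror `halt_max_steps` and `L_cycles` from the configs (see assets/TRM_pseudocode.png for the paper's actual pseudocode):

```python
import torch
import torch.nn as nn

def trm_refine(x, y, z, net, K=16, n=6):
    """Illustrative sketch of TRM's refinement loop (not the repo's code).

    x: embedded question, y: embedded answer, z: latent state.
    Combining inputs by summation is an assumption of this sketch.
    """
    for _ in range(K):       # up to K improvement steps
        for _ in range(n):   # i) recursive latent reasoning over z
            z = net(x + y + z)
        y = net(y + z)       # ii) refine the answer y from (y, z)
    return y, z

# Toy usage with a 2-layer MLP standing in for the tiny network:
hidden = 512
net = nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden))
x, y, z = (torch.randn(1, hidden) for _ in range(3))
y, z = trm_refine(x, y, z, net)
```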
### Requirements

- Python 3.10 (or similar)
- CUDA 12.6.0 (or similar)

```bash
pip install --upgrade pip wheel setuptools
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126 # install torch matching your CUDA version
pip install -r requirements.txt # install requirements
pip install --no-cache-dir --no-build-isolation adam-atan2
wandb login YOUR-LOGIN # log in if you want the logger to sync results to Weights & Biases (https://wandb.ai/)
```
### Dataset Preparation

```bash
# ARC-AGI-1
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc1concept-aug-1000 \
  --subsets training evaluation concept \
  --test-set-name evaluation

# ARC-AGI-2
python -m dataset.build_arc_dataset \
  --input-file-prefix kaggle/combined/arc-agi \
  --output-dir data/arc2concept-aug-1000 \
  --subsets training2 evaluation2 concept \
  --test-set-name evaluation2

# Note: you cannot train on both ARC-AGI-1 and ARC-AGI-2 and evaluate on both,
# because the ARC-AGI-2 training data contains some ARC-AGI-1 eval data.

# Sudoku-Extreme
python dataset/build_sudoku_dataset.py --output-dir data/sudoku-extreme-1k-aug-1000 --subsample-size 1000 --num-aug 1000 # 1000 examples, 1000 augments

# Maze-Hard
python dataset/build_maze_dataset.py # 1000 examples, 8 augments
```
## Experiments

### ARC-AGI (assuming 4 H100 GPUs):

```bash
run_name="pretrain_att_arc12concept_4"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/arc12concept-aug-1000]" \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```

*Runtime:* ~3 days
### Sudoku-Extreme (assuming 1 L40S GPU):

```bash
run_name="pretrain_mlp_t_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.mlp_t=True arch.pos_encodings=none \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True

run_name="pretrain_att_sudoku"
python pretrain.py \
  arch=trm \
  data_paths="[data/sudoku-extreme-1k-aug-1000]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=6 \
  +run_name=${run_name} ema=True
```

*Runtime:* < 36 hours
### Maze-Hard (assuming 4 L40S GPUs):

```bash
run_name="pretrain_att_maze30x30"
torchrun --nproc-per-node 4 --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 pretrain.py \
  arch=trm \
  data_paths="[data/maze-30x30-hard-1k]" \
  evaluators="[]" \
  epochs=50000 eval_interval=5000 \
  lr=1e-4 puzzle_emb_lr=1e-4 weight_decay=1.0 puzzle_emb_weight_decay=1.0 \
  arch.L_layers=2 \
  arch.H_cycles=3 arch.L_cycles=4 \
  +run_name=${run_name} ema=True
```

*Runtime:* < 24 hours
## Reference

If you find our work useful, please consider citing:

```bibtex
@misc{jolicoeurmartineau2025tinyrecursionmodel,
      title={Less is More: Recursive Reasoning with Tiny Networks},
      author={Alexia Jolicoeur-Martineau},
      year={2025},
      eprint={2510.04871},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2510.04871},
}
```

and the Hierarchical Reasoning Model (HRM):

```bibtex
@misc{wang2025hierarchicalreasoningmodel,
      title={Hierarchical Reasoning Model},
      author={Guan Wang and Jin Li and Yuhao Sun and Xing Chen and Changling Liu and Yue Wu and Meng Lu and Sen Song and Yasin Abbasi Yadkori},
      year={2025},
      eprint={2506.21734},
      archivePrefix={arXiv},
      primaryClass={cs.AI},
      url={https://arxiv.org/abs/2506.21734},
}
```

This code is based on the Hierarchical Reasoning Model [code](https://github.com/sapientinc/HRM) and the Hierarchical Reasoning Model Analysis [code](https://github.com/arcprize/hierarchical-reasoning-model-analysis).

assets/TRM_fig.png

346 KB

assets/TRM_pseudocode.png

261 KB

config/arch/hrm.yaml

Lines changed: 24 additions & 0 deletions
name: recursive_reasoning.hrm@HierarchicalReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4

hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use an MLP on L instead of a transformer

Lines changed: 18 additions & 0 deletions
name: recursive_reasoning.transformers_baseline@Model_ACTV2
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 1 # kept for compatibility
H_layers: 8

hidden_size: 512
num_heads: 12
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope

config/arch/trm.yaml

Lines changed: 26 additions & 0 deletions
name: recursive_reasoning.trm@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use an MLP on L instead of a transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense
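For intuition, the halting objective implied by `no_ACT_continue: True` might look like the following sketch. This is an assumption for illustration, not the repo's exact loss code: the halt logit is trained with a plain sigmoid/BCE signal against whether the current answer is already fully correct, with no separate "continue" term.

```python
import torch
import torch.nn.functional as F

def halt_loss(halt_logits, preds, targets):
    # Train the halt logit to predict "the answer is already fully
    # correct" via binary cross-entropy; no separate continue loss.
    correct = (preds == targets).all(dim=-1).float()  # 1 if fully correct
    return F.binary_cross_entropy_with_logits(halt_logits, correct)
```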

config/arch/trm_hier6.yaml

Lines changed: 26 additions & 0 deletions
name: recursive_reasoning.trm_hier6@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use an MLP on L instead of a transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense

config/arch/trm_singlez.yaml

Lines changed: 26 additions & 0 deletions
name: recursive_reasoning.trm_singlez@TinyRecursiveReasoningModel_ACTV1
loss:
  name: losses@ACTLossHead
  loss_type: stablemax_cross_entropy

halt_exploration_prob: 0.1
halt_max_steps: 16

H_cycles: 3
L_cycles: 6

H_layers: 0
L_layers: 2

hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
expansion: 4

puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope
forward_dtype: bfloat16

mlp_t: False # use an MLP on L instead of a transformer
puzzle_emb_len: 16 # if non-zero, the puzzle embedding length is set to this value
no_ACT_continue: True # drop the ACT continue loss and use only the sigmoid of the halt logit, which makes much more sense

config/cfg_pretrain.yaml

Lines changed: 42 additions & 0 deletions
# ARC training config

defaults:
  - arch: trm
  - _self_

hydra:
  output_subdir: null

# Data path
data_paths: ['data/arc-aug-1000']
data_paths_test: []

evaluators:
  - name: arc@ARC

# Hyperparams - Training
global_batch_size: 768

epochs: 100000
eval_interval: 10000
checkpoint_every_eval: True

lr: 1e-4
lr_min_ratio: 1.0
lr_warmup_steps: 2000

# Standard hyperparameter settings for LMs, as used in Llama
beta1: 0.9
beta2: 0.95
weight_decay: 0.1
puzzle_emb_weight_decay: 0.1

# Hyperparams - Puzzle embeddings training
puzzle_emb_lr: 1e-2

seed: 0
min_eval_interval: 0 # when to start evaluating

ema: False # use an Exponential Moving Average of the weights
ema_rate: 0.999 # EMA rate
freeze_weights: False # if True, freeze the weights and learn only the embeddings
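For reference, the weight EMA toggled by `ema` and `ema_rate` is typically maintained as below. This is an illustrative sketch of the standard technique, not the repo's implementation:

```python
import copy
import torch

@torch.no_grad()
def update_ema(ema_model, model, rate=0.999):
    # Blend each EMA parameter toward the live parameter:
    # p_ema <- rate * p_ema + (1 - rate) * p
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(rate).add_(p, alpha=1.0 - rate)

# Usage: ema_model starts as a frozen copy of the model,
#   ema_model = copy.deepcopy(model).eval()
# then call update_ema(ema_model, model, rate=0.999) after each optimizer step.
```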
