Skip to content

Commit a6df540

Browse files
committed
feat: Remove ACT in TRM
1 parent 97f9071 commit a6df540

File tree

1 file changed

+0
-33
lines changed
  • src/recursion/models/recursive_reasoning/trm.py

1 file changed

+0
-33
lines changed

src/recursion/models/recursive_reasoning/trm.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -356,39 +356,6 @@ def forward(
356356

357357
halted = is_last_step
358358

359-
# if training, and ACT is enabled
360-
if self.training and (self.config.halt_max_steps > 1):
361-
362-
# Halt signal
363-
# NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
364-
365-
if self.config.no_ACT_continue:
366-
halted = halted | (q_halt_logits > 0)
367-
else:
368-
halted = halted | (q_halt_logits > q_continue_logits)
369-
370-
# Exploration
371-
min_halt_steps = (
372-
torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob
373-
) * torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
374-
halted = halted & (new_steps >= min_halt_steps)
375-
376-
if not self.config.no_ACT_continue:
377-
# Compute target Q
378-
# NOTE: No replay buffer and target networks for computing target Q-value.
379-
# As batch_size is large, there're many parallel envs.
380-
# Similar concept as PQN https://arxiv.org/abs/2407.04811
381-
_, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(
382-
new_inner_carry, new_current_data
383-
)
384-
outputs["target_q_continue"] = torch.sigmoid(
385-
torch.where(
386-
is_last_step,
387-
next_q_halt_logits,
388-
torch.maximum(next_q_halt_logits, next_q_continue_logits),
389-
)
390-
)
391-
392359
return (
393360
TinyRecursiveReasoningModel_ACTV1Carry(
394361
new_inner_carry, new_steps, halted, new_current_data

0 commit comments

Comments
 (0)