@@ -356,39 +356,6 @@ def forward(
356356
357357 halted = is_last_step
358358
359- # if training, and ACT is enabled
360- if self .training and (self .config .halt_max_steps > 1 ):
361-
362- # Halt signal
363- # NOTE: During evaluation, always use max steps, this is to guarantee the same halting steps inside a batch for batching purposes
364-
365- if self .config .no_ACT_continue :
366- halted = halted | (q_halt_logits > 0 )
367- else :
368- halted = halted | (q_halt_logits > q_continue_logits )
369-
370- # Exploration
371- min_halt_steps = (
372- torch .rand_like (q_halt_logits ) < self .config .halt_exploration_prob
373- ) * torch .randint_like (new_steps , low = 2 , high = self .config .halt_max_steps + 1 )
374- halted = halted & (new_steps >= min_halt_steps )
375-
376- if not self .config .no_ACT_continue :
377- # Compute target Q
378- # NOTE: No replay buffer and target networks for computing target Q-value.
379- # As batch_size is large, there're many parallel envs.
380- # Similar concept as PQN https://arxiv.org/abs/2407.04811
381- _ , _ , (next_q_halt_logits , next_q_continue_logits ), _ , _ = self .inner (
382- new_inner_carry , new_current_data
383- )
384- outputs ["target_q_continue" ] = torch .sigmoid (
385- torch .where (
386- is_last_step ,
387- next_q_halt_logits ,
388- torch .maximum (next_q_halt_logits , next_q_continue_logits ),
389- )
390- )
391-
392359 return (
393360 TinyRecursiveReasoningModel_ACTV1Carry (
394361 new_inner_carry , new_steps , halted , new_current_data
0 commit comments