@@ -867,7 +867,10 @@ def fit_batch(self, batch):
867867 with torch.cuda.amp.autocast():
868868 outputs = self.compute_forward(batch, Stage.TRAIN)
869869 loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
870- self.scaler.scale(loss / self.grad_accumulation_factor).backward()
870+ with self.no_sync(not should_step):
871+ self.scaler.scale(
872+ loss / self.grad_accumulation_factor
873+ ).backward()
871874 if should_step:
872875 self.scaler.unscale_(self.optimizer)
873876 if self.check_gradients(loss):
@@ -877,7 +880,8 @@ def fit_batch(self, batch):
877880 else:
878881 outputs = self.compute_forward(batch, Stage.TRAIN)
879882 loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
880- (loss / self.grad_accumulation_factor).backward()
883+ with self.no_sync(not should_step):
884+ (loss / self.grad_accumulation_factor).backward()
881885 if should_step:
882886 if self.check_gradients(loss):
883887 self.optimizer.step()
@@ -888,7 +892,20 @@ def fit_batch(self, batch):
888892 return loss.detach().cpu()
889893
890894 def on_fit_batch_end(self, batch, outputs, loss, should_step):
891- """Called after ``fit_batch()``"""
895+ """Called after ``fit_batch()``, meant for calculating and logging metrics.
896+
897+ Arguments
898+ ---------
899+ batch : list of torch.Tensors
900+ Batch of data to use for training. Default implementation assumes
901+ this batch has two elements: inputs and targets.
902+ outputs : list or dictionary of torch.Tensors
903+ Returned value of compute_forward().
904+ loss : torch.Tensor
905+ Returned value of compute_objectives().
906+ should_step : boolean
907+ Whether optimizer.step() was called or not.
908+ """
892909 pass
893910
894911 def check_gradients(self, loss):
0 commit comments