Skip to content

Commit 0b0ec9d

Browse files
committed
using no_sync() in fit_batch() of core.py
1 parent 9077820 commit 0b0ec9d

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

speechbrain/core.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,10 @@ def fit_batch(self, batch):
867867
with torch.cuda.amp.autocast():
868868
outputs = self.compute_forward(batch, Stage.TRAIN)
869869
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
870-
self.scaler.scale(loss / self.grad_accumulation_factor).backward()
870+
with self.no_sync(not should_step):
871+
self.scaler.scale(
872+
loss / self.grad_accumulation_factor
873+
).backward()
871874
if should_step:
872875
self.scaler.unscale_(self.optimizer)
873876
if self.check_gradients(loss):
@@ -877,7 +880,8 @@ def fit_batch(self, batch):
877880
else:
878881
outputs = self.compute_forward(batch, Stage.TRAIN)
879882
loss = self.compute_objectives(outputs, batch, Stage.TRAIN)
880-
(loss / self.grad_accumulation_factor).backward()
883+
with self.no_sync(not should_step):
884+
(loss / self.grad_accumulation_factor).backward()
881885
if should_step:
882886
if self.check_gradients(loss):
883887
self.optimizer.step()
@@ -888,7 +892,20 @@ def fit_batch(self, batch):
888892
return loss.detach().cpu()
889893

890894
def on_fit_batch_end(self, batch, outputs, loss, should_step):
891-
"""Called after ``fit_batch()``"""
895+
"""Called after ``fit_batch()``, meant for calculating and logging metrics.
896+
897+
Arguments
898+
---------
899+
batch : list of torch.Tensors
900+
Batch of data to use for training. Default implementation assumes
901+
this batch has two elements: inputs and targets.
902+
outputs : list or dictionary of torch.Tensors
903+
Returned value of compute_forward().
904+
loss : torch.Tensor
905+
Returned value of compute_objectives().
906+
should_step : boolean
907+
Whether optimizer.step() was called or not.
908+
"""
892909
pass
893910

894911
def check_gradients(self, loss):

0 commit comments

Comments
 (0)