import time
import logging

try:
    import psutil
except ImportError:
    # Nothing visible in this module uses psutil, so a missing install should
    # not make the timers unusable. The name stays defined either way.
    psutil = None

import torch


def print_rank_0(message):
    """Print ``message`` on rank 0 only, or unconditionally outside a job.

    When ``torch.distributed`` is available and initialized, only the process
    whose rank is 0 prints. In every other situation (no distributed support
    in this torch build, or no process group initialized) the message is
    printed by the calling process. Output is flushed immediately.
    """
    distributed = torch.distributed
    # is_available() must be checked first: builds without distributed
    # support are not guaranteed to expose is_initialized() at all.
    if distributed.is_available() and distributed.is_initialized():
        if distributed.get_rank() != 0:
            return
    print(message, flush=True)


class SynchronizedWallClockTimer:
    """Group of named wall-clock timers. Borrowed from Nvidia Megatron code.

    Calling the instance with a name returns the timer registered under that
    name, creating it on first use.
    """

    class Timer:
        """A single start/stop timer that accumulates elapsed seconds."""

        def __init__(self, name):
            self.name_ = name
            self.elapsed_ = 0.0
            self.started_ = False
            self.start_time = time.time()

        @staticmethod
        def _synchronize():
            """Wait for pending CUDA work, when CUDA is usable at all."""
            # Without this guard the timer cannot be used on a CPU-only host,
            # where torch.cuda.synchronize() raises.
            if torch.cuda.is_available():
                torch.cuda.synchronize()

        def start(self):
            """Start the timer.

            Raises:
                AssertionError: if the timer is already running.
            """
            assert not self.started_, 'timer has already been started'
            self._synchronize()
            self.start_time = time.time()
            self.started_ = True

        def stop(self):
            """Stop the timer and add the interval to the accumulated total.

            Raises:
                AssertionError: if the timer is not running.
            """
            assert self.started_, 'timer is not started'
            self._synchronize()
            self.elapsed_ += (time.time() - self.start_time)
            self.started_ = False

        def reset(self):
            """Clear the accumulated time and mark the timer as stopped."""
            self.elapsed_ = 0.0
            self.started_ = False

        def elapsed(self, reset=True):
            """Return the accumulated time in seconds.

            A running timer is stopped to fold the in-progress interval into
            the total and is restarted before returning.

            Args:
                reset: when true, the accumulated time is cleared after it
                    has been read.
            """
            was_running = self.started_
            # If the timing is in progress, end it first.
            if was_running:
                self.stop()
            total = self.elapsed_
            if reset:
                self.reset()
            # If timing was in progress, set it back.
            if was_running:
                self.start()
            return total

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer called ``name``, creating it if necessary."""
        if name not in self.timers:
            self.timers[name] = self.Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers as one line through ``print_rank_0``.

        Args:
            names: iterable of timer names; each must already exist
                (an unknown name raises ``KeyError``).
            normalizer: positive divisor applied to every reported time.
            reset: forwarded to ``Timer.elapsed`` for each timer.
        """
        assert normalizer > 0.0
        parts = ['time (ms)']
        for name in names:
            # Seconds -> milliseconds, then normalize.
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            parts.append(' | {}: {:.2f}'.format(name, elapsed_time))
        print_rank_0(''.join(parts))


class ThroughputTimer(object):
    """Measure average step time and throughput over repeated steps.

    The first ``start_step`` completed steps are treated as warm-up: they are
    counted in ``count`` but their duration is not added to
    ``total_elapsed_time``. Averages are therefore taken over
    ``count - start_step`` steps.

    Args:
        name: label used by ``print_elapsed_time``.
        batch_size: samples per step per worker.
        num_workers: multiplier applied to ``batch_size`` for throughput.
        start_step: number of initial steps excluded from the averages.
    """

    def __init__(self, name=None, batch_size=1, num_workers=1, start_step=2):
        self.start_time = 0
        self.end_time = 0
        self.started = False
        self.count = 0
        self.total_elapsed_time = 0
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.start_step = start_step
        self.name = name

    def start(self, cond=True):
        """Begin timing a step; does nothing when ``cond`` is false."""
        if cond:
            self.start_time = time.time()
            self.started = True

    def stop(self, cond=True):
        """Finish timing a step; does nothing when ``cond`` is false.

        Raises:
            SystemExit: if called (with ``cond`` true) while not started.
        """
        if not cond:
            return
        if not self.started:
            print("Cannot stop timer without starting ")
            # Equivalent to the former exit(0), without relying on the
            # interactive helper installed by the site module.
            # NOTE(review): status 0 reports success for a misuse error -
            # confirm whether a non-zero status is wanted.
            raise SystemExit(0)
        self.end_time = time.time()
        self.started = False
        self.count += 1
        # Strictly greater: exactly `start_step` warm-up steps are skipped,
        # which matches the `count - start_step` divisor used for averages.
        if self.count > self.start_step:
            self.total_elapsed_time += self.end_time - self.start_time

    def _avg_time_per_step(self):
        """Return mean seconds per timed step, or None if not measurable."""
        timed_steps = self.count - self.start_step
        if timed_steps <= 0 or self.total_elapsed_time <= 0:
            return None
        return self.total_elapsed_time / float(timed_steps)

    def avg_samples_per_sec(self):
        """Return training samples per second, or -999 if not measurable."""
        avg_time_per_step = self._avg_time_per_step()
        if avg_time_per_step is None:
            return -999
        samples_per_step = self.batch_size * self.num_workers
        return samples_per_step / avg_time_per_step

    def avg_steps_per_sec(self):
        """Return steps per second, or -999 if not measurable."""
        avg_time_per_step = self._avg_time_per_step()
        if avg_time_per_step is None:
            return -999
        return 1 / avg_time_per_step

    def print_elapsed_time(self, num_ops=None):
        """Print the average step time on every 1000th completed step.

        Args:
            num_ops: optional operation count per step; when given, a TFlops
                figure (``num_ops`` per average step time, scaled by 1e12)
                is printed as well.
        """
        if self.count % 1000 != 0:
            return
        elapsed_time = self._avg_time_per_step()
        if elapsed_time is None:
            return
        if num_ops is None:
            print(self.name, " forward pass execution time: ", elapsed_time)
        else:
            print(self.name,
                  " forward pass execution time: ",
                  elapsed_time,
                  " TFlops : ",
                  num_ops / (elapsed_time * 1000000000000))