1515import torch
1616
# Nesting-depth counters for "restricted execution" sections.  A value >= 1
# means we are currently inside a main-process-only (MAIN_PROC_ONLY) or
# once-per-node (NODE_ONCE_ONLY) section, in which collective ops such as
# barriers must be skipped to avoid deadlocking the single running process.
MAIN_PROC_ONLY: int = 0
NODE_ONCE_ONLY: int = 0
1819
1920
2021def rank_prefixed_message (message : str ) -> str :
@@ -58,12 +59,29 @@ def get_rank() -> Optional[int]:
5859 return None
5960
6061
def get_local_rank() -> Optional[int]:
    r"""Get the local rank of the current process on the current node.

    Checks the environment variables set by the common launchers:
    ``LOCAL_RANK`` (torchrun and friends) and ``SLURM_LOCALID`` (SLURM).

    Returns
    -------
    int or None
        The local rank of the current process, or None if no relevant
        environment variable was set at all.
    """
    candidates = (os.environ.get(key) for key in ("LOCAL_RANK", "SLURM_LOCALID"))
    value = next((v for v in candidates if v is not None), None)
    # None (rather than e.g. 0) differentiates "no launcher env var present".
    return None if value is None else int(value)
78+
6179def infer_device () -> str :
6280 """Make a basic guess about intended running device based on
6381 availability and distributed environment variable 'LOCAL_RANK'"""
6482 if torch .cuda .is_available ():
6583 device = "cuda"
66- local_rank = os . environ . get ( "LOCAL_RANK" )
84+ local_rank = get_local_rank ( )
6785 if local_rank is not None :
6886 device += f":{ local_rank } "
6987 else :
@@ -136,6 +154,94 @@ def run_on_main(
136154 return result
137155
138156
def run_once_per_node(
    func,
    args=None,
    kwargs=None,
    post_func=None,
    post_args=None,
    post_kwargs=None,
    run_post_on_all=False,
):
    r"""Runs a function with DDP (multi-gpu) support.

    The provided function `func` is only run once on each node, while other processes
    block to wait for the function execution to finish. This is useful for things such
    as saving a file to the disk on each separate node (i.e. the filesystems are separate).
    In addition, a second function can be specified to be run on other processes after the
    first function completes, for example, loading a file that was created on each node.

    Arguments
    ---------
    func : callable
        Function to be run once on each node.
    args : list, None
        Positional args to pass to func.
    kwargs : dict, None
        Keyword args to pass to func.
    post_func : callable, None
        Function to run after `func` has finished. By default, `post_func` is not run
        on the process that ran `func`.
    post_args : list, None
        Positional args to pass to post_func.
    post_kwargs : dict, None
        Keyword args to pass to post_func.
    run_post_on_all : bool
        Whether to run post_func on all processes, including the process that ran `func`.

    Returns
    -------
    If `post_func` is provided, returns the result on all processes where `post_func` is run.
    If `run_post_on_all` is `False` or `post_func` is not provided, returns the result of `func` on the processes where it is run.
    If `post_func` is not provided, returns `None` on processes where `func` was not called.

    Example
    -------
    >>> tmpfile = getfixture("tmpdir") / "example.pt"
    >>> # Return tensor so we don't have to load it on the saving process
    >>> def save_and_return(file, tensor):
    ...     torch.save(tensor, file)
    ...     return tensor
    >>> # Load tensor on non-saving processes
    >>> def load_tensor(file):
    ...     return torch.load(file)
    >>> # Save on node-primary processes, load on others
    >>> example_tensor = torch.ones(5)
    >>> loaded_tensor = run_once_per_node(
    ...     func=save_and_return,
    ...     args=[tmpfile, example_tensor],
    ...     post_func=load_tensor,
    ...     post_args=[tmpfile],
    ...     run_post_on_all=False,
    ... )
    >>> # We should get the same result on all processes
    >>> loaded_tensor
    tensor([1., 1., 1., 1., 1.])
    """
    # Resolve mutable default args here instead of in the signature:
    args = args or []
    kwargs = kwargs or {}
    post_args = post_args or []
    post_kwargs = post_kwargs or {}

    # `once_per_node` runs `func` only on local-rank-zero processes and
    # returns None elsewhere; the barrier makes every process wait until
    # the node-primary processes have finished.
    result = once_per_node(func)(*args, **kwargs)
    ddp_barrier()

    # Call the post function if provided
    if post_func is not None:
        if run_post_on_all:
            # Every process runs post_func independently; no barrier needed.
            result = post_func(*post_args, **post_kwargs)
        else:
            # Complement of `once_per_node`: only the processes that did NOT
            # run `func` execute post_func, then everyone synchronizes.
            if not is_local_rank_zero():
                result = post_func(*post_args, **post_kwargs)
            ddp_barrier()

    return result
243+
244+
139245def is_distributed_initialized () -> bool :
140246 r"Returns whether the current system is distributed."
141247 # `is_initialized` is only defined conditionally
@@ -148,10 +254,12 @@ def is_distributed_initialized() -> bool:
148254
def if_main_process() -> bool:
    r"Returns whether the current process is the main process."
    # Outside of a distributed run every process counts as "main".
    if is_distributed_initialized():
        return get_rank() == 0
    return True
258+
259+
def is_local_rank_zero() -> bool:
    r"Returns whether the current process has local rank of 0."
    # Outside of a distributed run, treat the single process as rank zero.
    if is_distributed_initialized():
        return get_local_rank() == 0
    return True
155263
156264
157265class MainProcessContext :
@@ -174,6 +282,26 @@ def __exit__(self, exc_type, exc_value, traceback):
174282 MAIN_PROC_ONLY -= 1
175283
176284
class OncePerNodeContext:
    r"""
    Context manager to ensure code runs only once per node.
    This is useful to make sure that the module-level `NODE_ONCE_ONLY`
    counter is decremented even if an exception is raised inside the
    `once_per_node_wrapped_fn` function.
    """

    def __enter__(self):
        r"""Enter the context, incrementing the once-per-node counter."""
        global NODE_ONCE_ONLY
        NODE_ONCE_ONLY = NODE_ONCE_ONLY + 1
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        r"""Exit the context, decrementing the counter (even on exceptions)."""
        global NODE_ONCE_ONLY
        NODE_ONCE_ONLY = NODE_ONCE_ONLY - 1
303+
304+
177305def main_process_only (function ):
178306 r"""Function decorator to ensure the function runs only on the main process.
179307 This is useful for things like saving to the filesystem or logging
@@ -195,6 +323,37 @@ def main_proc_wrapped_func(*args, **kwargs):
195323 return main_proc_wrapped_func
196324
197325
def once_per_node(function):
    r"""Function decorator to ensure the function runs only once per node.
    This is useful for things like saving to the filesystem
    where you only want it to happen on a single process on each node.

    Unlike `main_process_only`, no broadcasting is done. Instead, processes
    with local_rank == 0 keep their own result, all other processes
    return None.
    """

    @wraps(function)
    def once_per_node_wrapped_fn(*args, **kwargs):
        """Runs the wrapped function only on local-rank-zero processes."""
        with OncePerNodeContext():
            if not is_local_rank_zero():
                return None
            return function(*args, **kwargs)

    return once_per_node_wrapped_fn
346+
347+
def ddp_prevent_block():
    r"Prevent blocking because only one or partial threads running."
    # Collective ops must be skipped when DDP is not initialized, or when we
    # are inside a main-process-only / once-per-node section (only a subset
    # of processes is executing, so a barrier would deadlock).
    if not is_distributed_initialized():
        return True
    return MAIN_PROC_ONLY >= 1 or NODE_ONCE_ONLY >= 1
355+
356+
198357def ddp_barrier ():
199358 r"""
200359 Synchronize all processes in distributed data parallel (DDP) mode.
@@ -216,7 +375,7 @@ def ddp_barrier():
216375 >>> print("hello world")
217376 hello world
218377 """
219- if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized ():
378+ if ddp_prevent_block ():
220379 return
221380
222381 if torch .distributed .get_backend () == torch .distributed .Backend .NCCL :
@@ -241,7 +400,7 @@ def ddp_broadcast(communication_object, src=0):
241400 -------
242401 The communication_object passed on rank src.
243402 """
244- if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized ():
403+ if ddp_prevent_block ():
245404 return communication_object
246405
247406 # Wrapping object in a list is required for preventing
@@ -271,7 +430,7 @@ def ddp_all_reduce(communication_object, reduce_op):
271430 """
272431
273432 # If DDP not initialised or executed with a main process barrier
274- if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized ():
433+ if ddp_prevent_block ():
275434 return communication_object
276435
277436 torch .distributed .all_reduce (communication_object , op = reduce_op )
@@ -296,12 +455,11 @@ def ddp_init_group(run_opts):
296455 -------
297456 None
298457 """
299- rank = os . environ . get ( "RANK" )
300- local_rank = os . environ . get ( "LOCAL_RANK" )
458+ rank = get_rank ( )
459+ local_rank = get_local_rank ( )
301460 if local_rank is None or rank is None :
302461 return
303462
304- local_rank = int (local_rank )
305463 if not run_opts ["distributed_backend" ] == "gloo" :
306464 if local_rank + 1 > torch .cuda .device_count ():
307465 raise ValueError (
0 commit comments