Skip to content

Commit a125f31

Browse files
authored
Add functions for running code once per node (#2992)
1 parent b7afd03 commit a125f31

3 files changed

Lines changed: 199 additions & 11 deletions

File tree

docs/multigpu.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,17 @@ MASTER=`echo $LISTNODES | cut -d" " -f1`
120120
torchrun --nproc_per_node=4 --nnodes=${SLURM_JOB_NUM_NODES} --node_rank=${SLURM_NODEID} --master_addr=${MASTER} --master_port=5555 train.py hparams/myrecipe.yaml
121121
```
122122

123+
#### Multi-node with separate filesystems
124+
125+
In addition to our `run_on_main` function, we provide a parallel function, `run_once_per_node`, which runs on every process with `LOCAL_RANK=0`. This is useful for setups where the nodes do not share a filesystem, so that checkpoints can be saved on each node's own filesystem.
126+
127+
To apply this to checkpointing, we provide the convenience function:
128+
129+
`speechbrain.utils.checkpoints.convert_torch_save_hooks_to_once_per_node()`
130+
131+
If you call this, each save happens once on every node, rather than only once in the main process.
132+
133+
123134
## (DEPRECATED) Single-node multi-GPU training using Data Parallel
124135

125136
[**We strongly recommend AGAINST using `DataParallel`, even for single-node setups**](https://pytorch.org/docs/stable/generated/torch.nn.DataParallel.html)! Use `DistributedDataParallel` instead. We no longer provide support for `DataParallel`. Future PyTorch versions may even remove `DataParallel` altogether.

speechbrain/utils/checkpoints.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
ddp_broadcast,
7272
if_main_process,
7373
main_process_only,
74+
once_per_node,
7475
)
7576
from speechbrain.utils.logger import get_logger
7677

@@ -223,6 +224,15 @@ def torch_save(obj, path):
223224
torch.save(state_dict, path)
224225

225226

227+
@once_per_node
228+
def torch_save_once_per_node(obj, path):
229+
"""Copy of `torch_save` that is run once per node."""
230+
state_dict = obj.state_dict()
231+
if not state_dict:
232+
logger.warning(f"Saving an empty state_dict for {obj} in {path}.")
233+
torch.save(state_dict, path)
234+
235+
226236
def torch_parameter_transfer(obj, path):
227237
"""Non-strict Torch Module state_dict load.
228238
@@ -302,6 +312,15 @@ def _load_spm(obj, path):
302312
DEFAULT_LOAD_HOOKS[torch.optim.lr_scheduler.CyclicLR] = __wa._cycliclrloader
303313

304314

315+
def convert_torch_save_hooks_to_once_per_node():
316+
"""Update the save hooks to be run once per node. This should be called
317+
if you are running on more than one node with separate filesystems."""
318+
global DEFAULT_SAVE_HOOKS
319+
for obj, hook in DEFAULT_SAVE_HOOKS.items():
320+
if hook == torch_save:
321+
DEFAULT_SAVE_HOOKS[obj] = torch_save_once_per_node
322+
323+
305324
def mark_as_saver(method):
306325
"""Method decorator which marks given method as the checkpoint saving hook.
307326

speechbrain/utils/distributed.py

Lines changed: 169 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616

1717
MAIN_PROC_ONLY: int = 0
18+
NODE_ONCE_ONLY: int = 0
1819

1920

2021
def rank_prefixed_message(message: str) -> str:
@@ -58,12 +59,29 @@ def get_rank() -> Optional[int]:
5859
return None
5960

6061

62+
def get_local_rank() -> Optional[int]:
63+
r"""Get the local rank of the current process on the current node.
64+
65+
Returns
66+
-------
67+
int or None
68+
The local rank of the current process, or None if the local rank could not be determined.
69+
"""
70+
rank_keys = ("LOCAL_RANK", "SLURM_LOCALID")
71+
for key in rank_keys:
72+
rank = os.environ.get(key)
73+
if rank is not None:
74+
return int(rank)
75+
# None to differentiate whether an environment variable was set at all
76+
return None
77+
78+
6179
def infer_device() -> str:
6280
"""Make a basic guess about intended running device based on
6381
availability and distributed environment variable 'LOCAL_RANK'"""
6482
if torch.cuda.is_available():
6583
device = "cuda"
66-
local_rank = os.environ.get("LOCAL_RANK")
84+
local_rank = get_local_rank()
6785
if local_rank is not None:
6886
device += f":{local_rank}"
6987
else:
@@ -136,6 +154,94 @@ def run_on_main(
136154
return result
137155

138156

157+
def run_once_per_node(
158+
func,
159+
args=None,
160+
kwargs=None,
161+
post_func=None,
162+
post_args=None,
163+
post_kwargs=None,
164+
run_post_on_all=False,
165+
):
166+
r"""Runs a function with DPP (multi-gpu) support.
167+
168+
The provided function `func` is only run once on each node, while other processes
169+
block to wait for the function execution to finish. This is useful for things such
170+
as saving a file to the disk on each separate node (i.e. the filesystems are separate).
171+
In addition, a second function can be specified to be run on other processes after the
172+
first function completes, for example, loading a file that was created on each node.
173+
174+
Arguments
175+
---------
176+
func : callable
177+
Function to be run once on each node.
178+
args : list, None
179+
Positional args to pass to func.
180+
kwargs : dict, None
181+
Keyword args to pass to func.
182+
post_func : callable, None
183+
Function to run after `func` has finished. By default, `post_func` is not run
184+
on the process that ran `func`.
185+
post_args : list, None
186+
Positional args to pass to post_func.
187+
post_kwargs : dict, None
188+
Keyword args to pass to post_func.
189+
run_post_on_all : bool
190+
Whether to run post_func on all processes, including the process that ran `func`.
191+
192+
Returns
193+
-------
194+
If `post_func` is provided, returns the result on all processes where `post_func` is run.
195+
If `run_post_on_all` is `False` or `post_func` is not provided, returns the result of `func` on the processes where it is run.
196+
If `post_func` is not provided, returns `None` on processes where `func` was not called.
197+
198+
Example
199+
-------
200+
>>> tmpfile = getfixture("tmpdir") / "example.pt"
201+
>>> # Return tensor so we don't have to load it on the saving process
202+
>>> def save_and_return(file, tensor):
203+
... torch.save(tensor, file)
204+
... return tensor
205+
>>> # Load tensor on non-saving processes
206+
>>> def load_tensor(file):
207+
... return torch.load(file)
208+
>>> # Save on node-primary processes, load on others
209+
>>> example_tensor = torch.ones(5)
210+
>>> loaded_tensor = run_once_per_node(
211+
... func=save_and_return,
212+
... args=[tmpfile, example_tensor],
213+
... post_func=load_tensor,
214+
... post_args=[tmpfile],
215+
... run_post_on_all=False,
216+
... )
217+
>>> # We should get the same result on all processes
218+
>>> loaded_tensor
219+
tensor([1., 1., 1., 1., 1.])
220+
"""
221+
# Handle the mutable data types' default args:
222+
args = args or []
223+
kwargs = kwargs or {}
224+
post_args = post_args or []
225+
post_kwargs = post_kwargs or {}
226+
227+
# Call the function exactly once per node, wait on other processes
228+
result = once_per_node(func)(*args, **kwargs)
229+
ddp_barrier()
230+
231+
# Call the post function if provided
232+
if post_func is not None:
233+
if run_post_on_all:
234+
# Just run on every process without any barrier.
235+
result = post_func(*post_args, **post_kwargs)
236+
else:
237+
# Do the opposite of `once_per_node` and await result
238+
if not is_local_rank_zero():
239+
result = post_func(*post_args, **post_kwargs)
240+
ddp_barrier()
241+
242+
return result
243+
244+
139245
def is_distributed_initialized() -> bool:
140246
r"Returns whether the current system is distributed."
141247
# `is_initialized` is only defined conditionally
@@ -148,10 +254,12 @@ def is_distributed_initialized() -> bool:
148254

149255
def if_main_process() -> bool:
150256
r"Returns whether the current process is the main process."
151-
if is_distributed_initialized():
152-
return torch.distributed.get_rank() == 0
153-
else:
154-
return True
257+
return not is_distributed_initialized() or get_rank() == 0
258+
259+
260+
def is_local_rank_zero() -> bool:
261+
r"Returns whether the current process has local rank of 0."
262+
return not is_distributed_initialized() or get_local_rank() == 0
155263

156264

157265
class MainProcessContext:
@@ -174,6 +282,26 @@ def __exit__(self, exc_type, exc_value, traceback):
174282
MAIN_PROC_ONLY -= 1
175283

176284

285+
class OncePerNodeContext:
286+
r"""
287+
Context manager to ensure code runs only once per node.
288+
This is useful to make sure that `NODE_ONCE_ONLY` global variable
289+
is decreased even if there's an exception raised inside of the
290+
`once_per_node_wrapped_fn` function.
291+
"""
292+
293+
def __enter__(self):
294+
r"""Enter the context. Increase the counter."""
295+
global NODE_ONCE_ONLY
296+
NODE_ONCE_ONLY += 1
297+
return self
298+
299+
def __exit__(self, exc_type, exc_value, traceback):
300+
r"""Exit the context. Decrease the counter."""
301+
global NODE_ONCE_ONLY
302+
NODE_ONCE_ONLY -= 1
303+
304+
177305
def main_process_only(function):
178306
r"""Function decorator to ensure the function runs only on the main process.
179307
This is useful for things like saving to the filesystem or logging
@@ -195,6 +323,37 @@ def main_proc_wrapped_func(*args, **kwargs):
195323
return main_proc_wrapped_func
196324

197325

326+
def once_per_node(function):
327+
r"""Function decorator to ensure the function runs only once per node.
328+
This is useful for things like saving to the filesystem
329+
where you only want it to happen on a single process on each node.
330+
331+
Unlike `main_process_only`, no broadcasting is done. Instead, processes
332+
with local_rank == 0 keep their own result, all other processes
333+
return None.
334+
"""
335+
336+
@wraps(function)
337+
def once_per_node_wrapped_fn(*args, **kwargs):
338+
"""This decorated function runs only if this is the main process."""
339+
with OncePerNodeContext():
340+
if is_local_rank_zero():
341+
return function(*args, **kwargs)
342+
else:
343+
return None
344+
345+
return once_per_node_wrapped_fn
346+
347+
348+
def ddp_prevent_block():
349+
r"Prevent blocking because only one or partial threads running."
350+
return (
351+
MAIN_PROC_ONLY >= 1
352+
or NODE_ONCE_ONLY >= 1
353+
or not is_distributed_initialized()
354+
)
355+
356+
198357
def ddp_barrier():
199358
r"""
200359
Synchronize all processes in distributed data parallel (DDP) mode.
@@ -216,7 +375,7 @@ def ddp_barrier():
216375
>>> print("hello world")
217376
hello world
218377
"""
219-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized():
378+
if ddp_prevent_block():
220379
return
221380

222381
if torch.distributed.get_backend() == torch.distributed.Backend.NCCL:
@@ -241,7 +400,7 @@ def ddp_broadcast(communication_object, src=0):
241400
-------
242401
The communication_object passed on rank src.
243402
"""
244-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized():
403+
if ddp_prevent_block():
245404
return communication_object
246405

247406
# Wrapping object in a list is required for preventing
@@ -271,7 +430,7 @@ def ddp_all_reduce(communication_object, reduce_op):
271430
"""
272431

273432
# If DDP not initialised or executed with a main process barrier
274-
if MAIN_PROC_ONLY >= 1 or not is_distributed_initialized():
433+
if ddp_prevent_block():
275434
return communication_object
276435

277436
torch.distributed.all_reduce(communication_object, op=reduce_op)
@@ -296,12 +455,11 @@ def ddp_init_group(run_opts):
296455
-------
297456
None
298457
"""
299-
rank = os.environ.get("RANK")
300-
local_rank = os.environ.get("LOCAL_RANK")
458+
rank = get_rank()
459+
local_rank = get_local_rank()
301460
if local_rank is None or rank is None:
302461
return
303462

304-
local_rank = int(local_rank)
305463
if not run_opts["distributed_backend"] == "gloo":
306464
if local_rank + 1 > torch.cuda.device_count():
307465
raise ValueError(

0 commit comments

Comments
 (0)