[PJRT] Experimental support for torch.distributed and DDP on TPU v2/v3 #4520

will-cromar merged 10 commits into master

Conversation
@will-cromar is this one ready for review?
(force-pushed 8699230 to 81b14d1)
I'll take another pass tomorrow to polish and add some comments, but this should be ready for review.
```diff
 for step in range(steps):
-  # To make torch.randn produce same results across devices.
-  torch.manual_seed(2022 + step)
+  rng = torch.Generator().manual_seed(2022 + step)
```
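The reason a per-device generator with a fixed seed works here can be checked without any TPU involved: two `torch.Generator` instances seeded identically produce identical `torch.randn` draws, and neither touches the global RNG state. A minimal sketch:

```python
import torch

# Two independent generators with the same seed, as in the test above.
rng1 = torch.Generator().manual_seed(2022)
rng2 = torch.Generator().manual_seed(2022)

# Draws from identically-seeded generators match exactly...
a = torch.randn(4, generator=rng1)
b = torch.randn(4, generator=rng2)
assert torch.equal(a, b)

# ...without mutating the global RNG, unlike torch.manual_seed().
state_before = torch.random.get_rng_state()
torch.randn(4, generator=rng1)
assert torch.equal(state_before, torch.random.get_rng_state())
```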
Curious why we use torch.Generator() instead?
The other option would be to wrap these randn calls in a lock and give them a common global seed, but explicitly creating a new generator with the same seed seems clearer to me. I would have done the same for module initialization, but that case doesn't support a custom RNG.
One last comment: maybe add a comment to the test case explaining the reasoning behind this?
```diff
 pjrt._run_multiprocess(
-    util.ddp_correctness, ddp=ddp, use_large_net=True, debug=FLAGS.debug)
+    util.ddp_correctness,
+    init_method='pjrt://',
```
Wonder if we want to parameterize the init_method with env:// as well?
Good idea. Added another test that skips for TPU <= v3, since env:// doesn't work nicely with multithreading.
…/v3 (#4520)

* Implement multithreaded XLA process group
* Fix tests
* Merge PJRT MNIST test
* formatting
* Clarify random generation in test_ddp.py
* Mark some variables private
* Remove some extra comments
* Add test that uses env:// method
* Explain local RNG
* Explain --pjrt_distributed flag
* `ThreadLocalWorld` to enable multithreading
* `torch.distributed` `init_method` that uses PJRT runtime parameters and supports multithreading
* The `torch.distributed` "rank" will become the same as our "ordinal", meaning we have one fewer set of indices to track
* `pjrt.DistributedDataParallel`, now that the upstream version works on v3

Performance comparison using ResNet50 with fake data on TPU v3:
Example usage:
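The usage snippet did not survive scraping; a minimal sketch of what DDP training with the new `pjrt://` init method looks like, based on the API this PR describes (the model, shapes, and `_mp_fn` name here are illustrative, not from the PR):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.experimental.pjrt_backend  # registers the pjrt:// init method


def _mp_fn(index):
  # pjrt:// derives rank and world size from the PJRT runtime itself,
  # so no MASTER_ADDR/MASTER_PORT environment setup is required, and it
  # supports the multithreaded replicas used on TPU v2/v3.
  dist.init_process_group('xla', init_method='pjrt://')

  device = xm.xla_device()
  model = DDP(nn.Linear(128, 10).to(device))
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

  inputs = torch.randn(8, 128, device=device)
  loss = model(inputs).sum()
  loss.backward()  # DDP all-reduces gradients across replicas
  optimizer.step()
  xm.mark_step()


if __name__ == '__main__':
  xmp.spawn(_mp_fn)
```

Note that after `init_process_group`, the `torch.distributed` rank equals the PJRT ordinal, so the same index works for both APIs.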
Needs rebasing after #4504 merges
Follow-up: