Add TransformerEngine to PT 2.0 training images #3315
Conversation
roywei
left a comment
Let's also update the env var NCCL_ASYNC_ERROR_HANDLING=1 per customer request; this will make sure PyTorch errors out properly during distributed training.
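For context, the same variable can also be set per-process before the NCCL process group is created. This is a hedged sketch, not code from the PR; the PR itself sets it container-wide via `ENV` in the Dockerfile, which is the usual DLC approach:

```python
import os

# NCCL_ASYNC_ERROR_HANDLING=1 tells PyTorch's NCCL watchdog to tear the
# process down when a collective errors or times out, instead of letting
# the distributed job hang silently.
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"

# torch.distributed.init_process_group(backend="nccl", ...) would be
# called after this so the setting is picked up by the process group.
print(os.environ["NCCL_ASYNC_ERROR_HANDLING"])
```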
/rerun
# Install flash attn and NVIDIA transformer engine
RUN MAX_JOBS=4 pip install flash-attn==2.0.4 --no-build-isolation
RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@release_v0.12
ENV NCCL_ASYNC_ERROR_HANDLING=1
this is already defined on line number 63
Ack, added because of a different review comment, will remove
    pytorch_training, ec2_connection, region, gpu_only, ec2_instance_type, pt21_and_above_only
):
    """
    PT 2.1 reintroduces a dependency on CUDNN to support NVDA TransformerEngine. This test is to ensure that torch CUDNN matches system CUDNN in the container.
There is no PT 2.1 yet; this is an anticipatory test we are adding to ensure that the torch binaries are compiled with the same cuDNN as exists in the container.
    ).stdout.split()[-1]

    cudnn_from_torch = ec2_connection.run(
        f"nvidia-docker exec --user root {container_name} python -c 'from torch.backends import cudnn; print(cudnn.version())'",
this cudnn comes from pytorch and is not installed from an OS package, right?
This cuDNN is the version that torch was compiled with, not the DLC's cuDNN version; torch is effectively statically linked against cuDNN. While slightly different compile-time and system versions don't appear to cause problems, we're adding this test for future safety so that the versions don't drift out of sync.
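For reference, the integer that `torch.backends.cudnn.version()` prints packs the version as MAJOR*1000 + MINOR*100 + PATCHLEVEL (the cuDNN 8.x scheme). A minimal sketch of decoding it for comparison against the container's cuDNN; `parse_cudnn_version` is a hypothetical helper, not code from this PR:

```python
def parse_cudnn_version(v: int) -> tuple:
    """Decode CUDNN_VERSION = MAJOR*1000 + MINOR*100 + PATCHLEVEL (cuDNN 8.x)."""
    # Hypothetical helper for illustration only.
    return v // 1000, (v % 1000) // 100, v % 100

# e.g. torch.backends.cudnn.version() returning 8902 means cuDNN 8.9.2
print(parse_cudnn_version(8902))
```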
GitHub Issue #, if available:
Note:
If merging this PR should also close the associated Issue, please also add that Issue # to the Linked Issues section on the right.
All PRs are checked weekly for staleness. This PR will be closed if not updated in 30 days.
Description
Tests run
By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.