We provide a variety of examples for deep learning frameworks including PyTorch and JAX. Additionally, we offer Jupyter notebook tutorials and a selection of third-party examples. Please be aware that these third-party examples might need specific, older versions of dependencies to function properly.
- Accelerate Hugging Face Llama models with TE
- Provides code examples and explanations for integrating TE with LLaMA2 models.
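The core idea in that tutorial is swapping a model's framework-native layers for their TE equivalents. A minimal pure-Python sketch of that layer-swap pattern; `NativeLinear`, `TELinear`, and `swap_layers` are illustrative stand-ins, not Transformer Engine or Hugging Face APIs:

```python
class NativeLinear:
    """Placeholder for a framework-native linear layer."""
    def __init__(self, in_features, out_features):
        self.in_features = in_features
        self.out_features = out_features

class TELinear(NativeLinear):
    """Placeholder for the Transformer Engine drop-in replacement."""

def swap_layers(modules):
    """Walk a {name: module} tree, replacing NativeLinear with TELinear."""
    swapped = 0
    for name, module in modules.items():
        if isinstance(module, dict):
            swapped += swap_layers(module)        # recurse into submodules
        elif type(module) is NativeLinear:
            modules[name] = TELinear(module.in_features, module.out_features)
            swapped += 1
    return swapped

# A toy nested "model": one embedding projection plus one transformer block.
model = {
    "embed": NativeLinear(512, 512),
    "block0": {"attn_proj": NativeLinear(512, 512),
               "mlp": NativeLinear(512, 2048)},
}
count = swap_layers(model)  # replaces all three layers in place
```

The actual tutorial performs the equivalent substitution on Hugging Face decoder layers so that FP8 execution becomes available without retraining-side code changes.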
- PyTorch FSDP with FP8
- Distributed Training: How to set up and run distributed training using PyTorch’s FullyShardedDataParallel (FSDP) strategy.
- TE Integration: Instructions on integrating TE/FP8 with PyTorch for optimized performance.
- Checkpointing: Methods for applying activation checkpointing to manage memory usage during training.
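The checkpointing trade-off above can be sketched without any framework code. This is a minimal illustration of the idea, not the `torch.utils.checkpoint` API: instead of keeping an intermediate activation alive for the backward pass, store only the layer input and recompute the activation when it is needed:

```python
calls = {"g": 0}

def g(x):            # stands in for an expensive intermediate layer
    calls["g"] += 1
    return x * x

def h(y):            # stands in for the final layer
    return y + 1.0

def forward_plain(x):
    a = g(x)
    return h(a), a          # (output, activation kept alive for backward)

def forward_checkpointed(x):
    return h(g(x)), x       # store only the cheap input

def backward_checkpointed(stored_x):
    return g(stored_x)      # recompute the activation on demand

out_plain, act = forward_plain(3.0)
out_ckpt, stored = forward_checkpointed(3.0)
recomputed = backward_checkpointed(stored)
# Same output either way; checkpointing trades one extra call to g
# for not holding the activation in memory between forward and backward.
```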
- Attention backends in TE
- Attention Backends: Describes various attention backends supported by Transformer Engine, including framework-native, fused, and flash-attention backends, and their performance benefits.
- Flash vs. Non-Flash: Compares the flash algorithm with the standard non-flash algorithm, highlighting memory and computational efficiency improvements.
- Backend Selection: Details the logic for selecting the most appropriate backend based on availability and performance, and provides user control options for backend selection.
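The memory advantage of the flash algorithm comes from never materializing the full row of attention scores: it streams over them once, maintaining a running max and a rescaled running sum. A pure-Python sketch of that numerically stable "online" softmax update (the core trick; the real kernels tile this over GPU SRAM):

```python
import math

def softmax_naive(scores):
    """Standard two-pass softmax over fully materialized scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_online(scores):
    """Single streaming pass: rescale the running sum when the max grows."""
    m, total = float("-inf"), 0.0
    for s in scores:
        new_m = max(m, s)
        total = total * math.exp(m - new_m) + math.exp(s - new_m)
        m = new_m
    return [math.exp(s - m) / total for s in scores]

scores = [1.0, 3.0, -2.0, 0.5]
naive = softmax_naive(scores)
online = softmax_online(scores)
```

Because the max and normalizer are finalized in one pass, a flash kernel only ever needs the running statistics plus the current tile of scores in fast memory.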
- Overlapping Communication with GEMM
- Training a TE module with GEMM and communication overlap, including various configurations and command-line arguments for customization.
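The overlap idea can be illustrated with a toy sketch; TE's real implementation uses CUDA streams and userbuffers rather than Python threads, and `all_gather`/`local_gemm` here are simplified stand-ins. Communication is launched in the background, the local GEMM chunk runs concurrently, and the two are synchronized before the result that depends on both:

```python
import threading

def all_gather(shard, results, rank):
    results[rank] = shard            # stand-in for network communication

def local_gemm(a_row, b_col):
    return sum(x * y for x, y in zip(a_row, b_col))

gathered = [None, None]
comm = threading.Thread(target=all_gather, args=([5, 6], gathered, 1))
comm.start()                          # communication proceeds in background
partial = local_gemm([1, 2], [3, 4])  # compute overlaps with the "comm"
gathered[0] = [partial]
comm.join()                           # synchronize before consuming results
```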
- Performance Optimizations
- Multi-GPU Training: How to use TE with data, tensor, and sequence parallelism.
- Gradient Accumulation Fusion: Utilizing Tensor Cores to accumulate outputs directly into FP32 for better numerical accuracy.
- FP8 Weight Caching: Avoiding redundant FP8 casting during multiple gradient accumulation steps to improve efficiency.
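The benefit of accumulating GEMM outputs directly in FP32 can be demonstrated in pure Python, simulating half- and single-precision rounding with `struct`'s `'e'` and `'f'` formats. Repeatedly adding a small increment in FP16 stalls once the running sum's spacing (ULP) exceeds twice the increment, while an FP32 accumulator keeps absorbing it:

```python
import struct

def round_fp16(x):
    return struct.unpack("e", struct.pack("e", x))[0]

def round_fp32(x):
    return struct.unpack("f", struct.pack("f", x))[0]

def accumulate(increment, steps, rounder):
    """Accumulate `increment` `steps` times, rounding the sum each step."""
    s = 0.0
    for _ in range(steps):
        s = rounder(s + increment)
    return s

# True sum is 2.0. The FP16 accumulator stalls around 0.25, where its
# ULP (~2.4e-4) exceeds twice the 1e-4 increment; FP32 stays accurate.
fp16_sum = accumulate(1e-4, 20_000, round_fp16)
fp32_sum = accumulate(1e-4, 20_000, round_fp32)
```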
- Introduction to FP8
- Overview of FP8 datatypes (E4M3, E5M2), mixed precision training, delayed scaling strategies, and code examples for FP8 configuration and usage.
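The two FP8 formats trade mantissa bits for exponent range, and delayed scaling picks a scale factor from the amax (max absolute value) history of previous iterations so the cast does not need a same-step reduction. A small sketch of both, computing each format's largest finite value from its bit layout; `delayed_scale` and its `margin` parameter are a simplified illustration of the recipe, not the TE API:

```python
# E4M3 (4 exponent bits, bias 7, 3 mantissa bits) reserves only the
# all-ones mantissa of its top exponent for NaN, so the max finite value
# uses exponent 1111 with mantissa 110:
e4m3_max = (1 + 6 / 8) * 2 ** (0b1111 - 7)     # 1.75 * 2**8  = 448.0

# E5M2 (5 exponent bits, bias 15, 2 mantissa bits) follows IEEE
# conventions (top exponent reserved for inf/NaN), so its max uses
# exponent 11110 with mantissa 11:
e5m2_max = (1 + 3 / 4) * 2 ** (0b11110 - 15)   # 1.75 * 2**15 = 57344.0

def delayed_scale(amax_history, fp8_max_val, margin=0):
    """Scale chosen so the historically largest value maps near fp8 max."""
    amax = max(amax_history)
    return (fp8_max_val / amax) / (2 ** margin)

history = [2.0, 3.5, 3.0]                  # amaxes observed in recent steps
scale = delayed_scale(history, e4m3_max)   # 448 / 3.5 = 128.0
```

E4M3's tighter range but finer precision suits forward activations and weights; E5M2's wider range suits gradients.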
- Basic MNIST Example (PyTorch)
- Basic Transformer Encoder Example
- Single GPU Training: Demonstrates setting up and training a Transformer model using a single GPU.
- Data Parallelism: Scale training across multiple GPUs using data parallelism.
- Model Parallelism: Divide a model across multiple GPUs for parallel training.
- Multiprocessing with Model Parallelism: Runs model parallelism across multiple processes, including multi-node support and hardware-affinity setup.
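The model-parallel scheme above can be sketched without a framework. In column-style tensor parallelism, each "device" holds a column shard of a linear layer's weight matrix, computes its partial output, and the shards are concatenated (an all-gather in a real multi-GPU setup). A pure-Python illustration with nested lists standing in for device-resident tensors:

```python
def matmul(x, w):
    """x: length-n vector, w: n-by-m matrix (list of rows) -> length-m."""
    cols = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(cols)]

def split_columns(w, parts):
    """Give each of `parts` devices a contiguous column shard of w."""
    step = len(w[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in w] for p in range(parts)]

x = [1.0, 2.0]
w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]

full = matmul(x, w)                         # single-device reference
shards = split_columns(w, 2)                # one shard per "device"
partials = [matmul(x, shard) for shard in shards]
gathered = partials[0] + partials[1]        # stands in for an all-gather
```

Splitting by columns means no reduction is needed for the output, only a gather; row-wise splits instead require an all-reduce of partial sums.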
- Basic MNIST Example (JAX)
- TE JAX Integration Tutorial
- Introduction to integrating TE into an existing JAX model framework, building a Transformer Layer, and instructions on integrating TE modules like Linear and LayerNorm.
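For reference alongside that tutorial, this is the math the LayerNorm module computes (TE's fused kernels produce the same result, just in one GPU pass): normalize over the feature dimension, then apply a learned scale (gamma) and shift (beta). A plain-Python sketch:

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize x to zero mean / unit variance, then scale and shift."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gamma, beta)]

features = [1.0, 2.0, 3.0, 4.0]
out = layer_norm(features, gamma=[1.0] * 4, beta=[0.0] * 4)
# With identity gamma/beta, the output has (approximately) zero mean
# and unit variance.
```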
- Hugging Face Accelerate + TE
- Scripts for training with Accelerate and TE. Supports single-GPU training and multi-GPU training via DDP, FSDP, and DeepSpeed ZeRO stages 1-3.