diff --git a/dist-lab/01_basic_collectives.py b/dist-lab/01_basic_collectives.py
new file mode 100644
index 0000000..7228ff5
--- /dev/null
+++ b/dist-lab/01_basic_collectives.py
@@ -0,0 +1,95 @@
+import os
+import argparse
+import torch
+import torch.distributed as dist
+
+
+def setup(backend: str):
+    dist.init_process_group(backend=backend)
+    rank = dist.get_rank()
+    world = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+
+    if backend == "nccl":
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device("cpu")
+
+    return rank, world, device
+
+
+def log(rank, msg):
+    print(f"[rank {rank}] {msg}", flush=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", choices=["gloo", "nccl"], default="gloo")
+    args = parser.parse_args()
+
+    rank, world, device = setup(args.backend)
+
+    dist.barrier()
+    log(rank, "barrier passed")
+
+    # 1) broadcast: rank 0 sends one value to everybody
+    x = torch.tensor([111 if rank == 0 else -1], device=device)
+    dist.broadcast(x, src=0)
+    log(rank, f"broadcast -> {x.tolist()}")
+
+    dist.barrier()
+
+    # 2) all_reduce: everybody contributes, everybody gets the sum
+    x = torch.tensor([rank + 1], device=device)
+    dist.all_reduce(x, op=dist.ReduceOp.SUM)
+    log(rank, f"all_reduce SUM -> {x.tolist()}")
+
+    dist.barrier()
+
+    # 3) reduce: everybody contributes, only rank 0 gets the sum
+    #    (on the other ranks, the tensor contents afterward are unspecified)
+    x = torch.tensor([rank + 1], device=device)
+    dist.reduce(x, dst=0, op=dist.ReduceOp.SUM)
+    log(rank, f"reduce to rank 0 -> {x.tolist()}")
+
+    dist.barrier()
+
+    # 4) all_gather: everybody receives one tensor from each rank
+    x = torch.tensor([rank], device=device)
+    out = [torch.zeros_like(x) for _ in range(world)]
+    dist.all_gather(out, x)
+    log(rank, f"all_gather -> {[t.item() for t in out]}")
+
+    dist.barrier()
+
+    # 5) gather: only rank 0 receives from everybody
+    x = torch.tensor([rank * 10], device=device)
+    if rank == 0:
+        gather_list = [torch.zeros_like(x) for _ in range(world)]
+    else:
+        gather_list = None
+    dist.gather(x, gather_list=gather_list, dst=0)
+    if rank == 0:
+        log(rank, f"gather -> {[t.item() for t in gather_list]}")
+
+    dist.barrier()
+
+    # 6) scatter: rank 0 sends one tensor to each rank
+    y = torch.empty(1, device=device, dtype=torch.int64)
+    if rank == 0:
+        scatter_list = [torch.tensor([100 + r], device=device) for r in range(world)]
+    else:
+        scatter_list = None
+    dist.scatter(y, scatter_list=scatter_list, src=0)
+    log(rank, f"scatter <- {y.tolist()}")
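+
+    # Sanity check (an illustrative addition, not part of the original flow):
+    # scatter handed rank r the value 100 + r, so each rank can verify locally.
+    assert y.item() == 100 + rank, "unexpected scatter result"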
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dist-lab/02_tensor_collectives.py b/dist-lab/02_tensor_collectives.py
new file mode 100644
index 0000000..887524f
--- /dev/null
+++ b/dist-lab/02_tensor_collectives.py
@@ -0,0 +1,96 @@
+import os
+import argparse
+import torch
+import torch.distributed as dist
+
+
+def setup(backend: str):
+    dist.init_process_group(backend=backend)
+    rank = dist.get_rank()
+    world = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+
+    if backend == "nccl":
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device("cpu")
+
+    return rank, world, device
+
+
+def log(rank, msg):
+    print(f"[rank {rank}] {msg}", flush=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", choices=["gloo", "nccl"], default="gloo")
+    args = parser.parse_args()
+
+    rank, world, device = setup(args.backend)
+
+    # 1) all_gather_into_tensor
+    x = torch.tensor([rank, rank + 10], device=device, dtype=torch.int64)
+    out = torch.empty(world * x.numel(), device=device, dtype=torch.int64)
+    dist.all_gather_into_tensor(out, x)
+    log(rank, f"all_gather_into_tensor -> {out.tolist()}")
+
+    dist.barrier()
+
+    # 2) reduce_scatter
+    # Every rank gives a list of length=world.
+    # Slot i across all ranks is reduced, rank i receives that result.
+    input_list = [
+        torch.tensor([rank * 100 + i], device=device, dtype=torch.int64)
+        for i in range(world)
+    ]
+    out = torch.empty(1, device=device, dtype=torch.int64)
+    dist.reduce_scatter(out, input_list, op=dist.ReduceOp.SUM)
+    log(rank, f"reduce_scatter -> {out.tolist()}")
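+
+    # Illustrative check: rank i receives slot i, and slot i reduces to
+    # sum over ranks r of (r * 100 + i) = 100 * world * (world - 1) / 2 + world * i.
+    assert out.item() == 100 * world * (world - 1) // 2 + world * rank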
+
+    dist.barrier()
+
+    # 3) reduce_scatter_tensor
+    chunk = 2
+    x = torch.arange(world * chunk, device=device, dtype=torch.int64) + rank * 1000
+    out = torch.empty(chunk, device=device, dtype=torch.int64)
+    dist.reduce_scatter_tensor(out, x, op=dist.ReduceOp.SUM)
+    log(rank, f"reduce_scatter_tensor -> {out.tolist()}")
+
+    dist.barrier()
+
+    # 4) all_to_all_single
+    #    each rank splits its input into world chunks; chunk j goes to rank j
+    if args.backend == "nccl":
+        chunk = 2
+        x = torch.arange(world * chunk, device=device, dtype=torch.int64) + rank * 10000
+        out = torch.empty_like(x)
+        dist.all_to_all_single(out, x)
+        log(rank, f"all_to_all_single -> {out.tolist()}")
+    else:
+        log(rank, "skip all_to_all_single on gloo")
+
+    dist.barrier()
+
+    # 5) all_to_all
+    if args.backend == "nccl":
+        input_list = [
+            torch.tensor([rank, i], device=device, dtype=torch.int64)
+            for i in range(world)
+        ]
+        output_list = [torch.empty(2, device=device, dtype=torch.int64) for _ in range(world)]
+        dist.all_to_all(output_list, input_list)
+        log(rank, f"all_to_all -> {[t.tolist() for t in output_list]}")
+    else:
+        log(rank, "skip all_to_all on gloo")
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dist-lab/03_objects_and_debug.py b/dist-lab/03_objects_and_debug.py
new file mode 100644
index 0000000..8a7611e
--- /dev/null
+++ b/dist-lab/03_objects_and_debug.py
@@ -0,0 +1,68 @@
+import argparse
+import torch.distributed as dist
+
+
+def log(rank, msg):
+    print(f"[rank {rank}] {msg}", flush=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", choices=["gloo"], default="gloo")
+    args = parser.parse_args()
+
+    dist.init_process_group(backend=args.backend)
+    rank = dist.get_rank()
+    world = dist.get_world_size()
+
+    # 1) broadcast_object_list
+    objs = [{"from": 0, "msg": "hello"}] if rank == 0 else [None]
+    dist.broadcast_object_list(objs, src=0)
+    log(rank, f"broadcast_object_list -> {objs}")
+
+    dist.barrier()
+
+    # 2) all_gather_object
+    gathered = [None for _ in range(world)]
+    dist.all_gather_object(gathered, {"rank": rank, "square": rank * rank})
+    log(rank, f"all_gather_object -> {gathered}")
+
+    dist.barrier()
+
+    # 3) gather_object
+    if rank == 0:
+        gathered_objs = [None for _ in range(world)]
+    else:
+        gathered_objs = None
+    dist.gather_object({"rank": rank, "cube": rank ** 3}, gathered_objs, dst=0)
+    if rank == 0:
+        log(rank, f"gather_object -> {gathered_objs}")
+
+    dist.barrier()
+
+    # 4) scatter_object_list
+    out_obj = [None]
+    if rank == 0:
+        in_objs = [f"obj_for_rank_{r}" for r in range(world)]
+    else:
+        in_objs = None
+    dist.scatter_object_list(out_obj, in_objs, src=0)
+    log(rank, f"scatter_object_list <- {out_obj}")
+
+    dist.barrier()
+
+    # 5) monitored_barrier
+    dist.monitored_barrier()
+    log(rank, "monitored_barrier passed")
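+
+    # Illustrative variant (not in the original flow): monitored_barrier also
+    # accepts a timeout and, with wait_all_ranks=True, names every late rank.
+    from datetime import timedelta
+    dist.monitored_barrier(timeout=timedelta(seconds=30), wait_all_ranks=True)
+    log(rank, "monitored_barrier with timeout passed")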
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dist-lab/04_p2p.py b/dist-lab/04_p2p.py
new file mode 100644
index 0000000..5eaded3
--- /dev/null
+++ b/dist-lab/04_p2p.py
@@ -0,0 +1,101 @@
+import os
+import argparse
+import torch
+import torch.distributed as dist
+
+
+def setup(backend: str):
+    dist.init_process_group(backend=backend)
+    rank = dist.get_rank()
+    world = dist.get_world_size()
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+
+    if backend == "nccl":
+        torch.cuda.set_device(local_rank)
+        device = torch.device("cuda", local_rank)
+    else:
+        device = torch.device("cpu")
+
+    return rank, world, device
+
+
+def log(rank, msg):
+    print(f"[rank {rank}] {msg}", flush=True)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--backend", choices=["gloo", "nccl"], default="gloo")
+    args = parser.parse_args()
+
+    rank, world, device = setup(args.backend)
+
+    if world % 2 != 0:
+        if rank == 0:
+            print("This demo expects an even world size.", flush=True)
+        dist.destroy_process_group()
+        return
+
+    peer = rank + 1 if rank % 2 == 0 else rank - 1
+
+    # 0) sync so everybody starts together
+    dist.barrier()
+
+    # 1) blocking send / recv
+    send_tensor = torch.tensor([rank], device=device, dtype=torch.int64)
+    recv_tensor = torch.empty(1, device=device, dtype=torch.int64)
+
+    if rank % 2 == 0:
+        dist.send(send_tensor, dst=peer)
+        dist.recv(recv_tensor, src=peer)
+    else:
+        dist.recv(recv_tensor, src=peer)
+        dist.send(send_tensor, dst=peer)
+
+    log(rank, f"send/recv with peer {peer} -> received {recv_tensor.tolist()}")
+
+    dist.barrier()
+
+    # 2) non-blocking isend / irecv
+    send_tensor = torch.tensor([rank + 100], device=device, dtype=torch.int64)
+    recv_tensor = torch.empty(1, device=device, dtype=torch.int64)
+
+    if rank % 2 == 0:
+        req1 = dist.isend(send_tensor, dst=peer)
+        req2 = dist.irecv(recv_tensor, src=peer)
+    else:
+        req2 = dist.irecv(recv_tensor, src=peer)
+        req1 = dist.isend(send_tensor, dst=peer)
+
+    req1.wait()
+    req2.wait()
+
+    log(rank, f"isend/irecv with peer {peer} -> received {recv_tensor.tolist()}")
+
+    dist.barrier()
+
+    # 3) batch_isend_irecv in a ring: every rank sends to next and receives from prev
+    send_to = (rank + 1) % world
+    recv_from = (rank - 1 + world) % world
+
+    send_tensor = torch.tensor([rank, rank + 1000], device=device, dtype=torch.int64)
+    recv_tensor = torch.empty(2, device=device, dtype=torch.int64)
+
+    ops = [
+        dist.P2POp(dist.isend, send_tensor, send_to),
+        dist.P2POp(dist.irecv, recv_tensor, recv_from),
+    ]
+    reqs = dist.batch_isend_irecv(ops)
+    for req in reqs:
+        req.wait()
+
+    log(rank, f"batch_isend_irecv: recv from {recv_from} -> {recv_tensor.tolist()}")
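+
+    # Illustrative check: the ring delivers the previous rank's payload.
+    assert recv_tensor.tolist() == [recv_from, recv_from + 1000]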
+
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dist-lab/readme.md b/dist-lab/readme.md
new file mode 100644
index 0000000..03b75c2
--- /dev/null
+++ b/dist-lab/readme.md
@@ -0,0 +1,168 @@
+# PyTorch Distributed Lab
+
+A hands-on lab covering PyTorch `torch.distributed` primitives across 4 files.
+
+## Quick Start (Fresh VM)
+
+Run this on **every node** to bootstrap the environment:
+
+```bash
+# CPU-only (gloo backend)
+bash setup.sh
+
+# With GPU/CUDA support (nccl backend)
+bash setup.sh --gpu
+```
+
+Then activate the venv:
+
+```bash
+source .venv/bin/activate
+```
+
+The script handles: system deps, uv installation, Python 3.11 venv, PyTorch, and useful extras (tensorboard, pynvml, psutil, py-spy, rich).
+
+## Setup
+
+- **2 nodes**, **2 processes per node**, **world size = 4**
+- Replace `10.0.0.10` with your node0 IP
+- Replace `eth0` with your actual network interface (`ip addr` to check)
+- Put the same files on both nodes
+
+## Backend Guide
+
+| Backend | Device | Use for |
+|---------|--------|---------|
+| `gloo` | CPU | Files 01, 03, 04 — and most of 02 |
+| `nccl` | GPU | File 02 when you want `all_to_all` / `all_to_all_single` |
+
+## Run Commands
+
+Run each command **on both nodes**. The commands are identical: with the c10d rendezvous backend, `torchrun` assigns node ranks automatically, so no `--node-rank` flag is needed.
+
+### 1) Basic Collectives (gloo/CPU)
+
+Covers: `barrier`, `broadcast`, `all_reduce`, `reduce`, `all_gather`, `gather`, `scatter`
+
+```bash
+export GLOO_SOCKET_IFNAME=eth0
+
+torchrun \
+  --nnodes=2 \
+  --nproc-per-node=2 \
+  --rdzv-id=toy-dist-01 \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=10.0.0.10:29400 \
+  01_basic_collectives.py \
+  --backend gloo
+```
+
+### 2) Point-to-Point (gloo/CPU)
+
+Covers: `send`, `recv`, `isend`, `irecv`, `batch_isend_irecv`
+
+```bash
+export GLOO_SOCKET_IFNAME=eth0
+
+torchrun \
+  --nnodes=2 \
+  --nproc-per-node=2 \
+  --rdzv-id=toy-dist-01 \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=10.0.0.10:29400 \
+  04_p2p.py \
+  --backend gloo
+```
+
+### 3) Objects & Debug (gloo only)
+
+Covers: `broadcast_object_list`, `all_gather_object`, `gather_object`, `scatter_object_list`, `monitored_barrier`
+
+```bash
+export GLOO_SOCKET_IFNAME=eth0
+
+torchrun \
+  --nnodes=2 \
+  --nproc-per-node=2 \
+  --rdzv-id=toy-dist-01 \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=10.0.0.10:29400 \
+  03_objects_and_debug.py \
+  --backend gloo
+```
+
+### 4) Tensor Collectives — gloo (skips all_to_all)
+
+Covers: `all_gather_into_tensor`, `reduce_scatter`, `reduce_scatter_tensor`
+
+```bash
+export GLOO_SOCKET_IFNAME=eth0
+
+torchrun \
+  --nnodes=2 \
+  --nproc-per-node=2 \
+  --rdzv-id=toy-dist-01 \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=10.0.0.10:29400 \
+  02_tensor_collectives.py \
+  --backend gloo
+```
+
+### 5) Tensor Collectives — nccl (includes all_to_all)
+
+Covers: everything in step 4 **plus** `all_to_all_single`, `all_to_all`
+
+```bash
+export NCCL_SOCKET_IFNAME=eth0
+
+torchrun \
+  --nnodes=2 \
+  --nproc-per-node=2 \
+  --rdzv-id=toy-dist-01 \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=10.0.0.10:29400 \
+  02_tensor_collectives.py \
+  --backend nccl
+```
+
+## Recommended Order
+
+| Step | File | Backend | Why |
+|------|------|---------|-----|
+| 1 | `01_basic_collectives.py` | gloo | "Who talks to everyone" — the fundamentals |
+| 2 | `04_p2p.py` | gloo | "Who talks to one peer" — send/recv patterns |
+| 3 | `03_objects_and_debug.py` | gloo | Object collectives + `monitored_barrier` for debugging |
+| 4 | `02_tensor_collectives.py` | gloo | Shape-sensitive tensor collectives (skips all_to_all) |
+| 5 | `02_tensor_collectives.py` | nccl | Re-run with GPUs to unlock `all_to_all*` |
+
+## Single-Node Quick Test
+
+To test locally on one machine with no multi-node setup:
+
+```bash
+torchrun \
+  --nnodes=1 \
+  --nproc-per-node=4 \
+  --rdzv-id=local-test \
+  --rdzv-backend=c10d \
+  --rdzv-endpoint=127.0.0.1:29400 \
+  01_basic_collectives.py \
+  --backend gloo
+```
+
+## Notes
+
+- `monitored_barrier()` is Gloo-only — useful for debugging desynchronized ranks
+- `all_to_all` / `all_to_all_single` require NCCL (not supported on Gloo)
+- Object collectives use pickle under the hood — fine for learning, not for production perf
+- If interface auto-detection fails, set `GLOO_SOCKET_IFNAME` or `NCCL_SOCKET_IFNAME` (see the snippet below)
+- If `--rdzv-endpoint` omits the port, torchrun defaults to 29400
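+
+A quick way to find the right interface name (a sketch; `eth0` below is a placeholder for your cluster-facing interface):
+
+```bash
+# List IPv4 interfaces and their addresses, then pick the cluster network one
+ip -o -4 addr show | awk '{print $2, $4}'
+export GLOO_SOCKET_IFNAME=eth0   # for gloo runs
+export NCCL_SOCKET_IFNAME=eth0   # for nccl runs
+```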
diff --git a/dist-lab/setup.sh b/dist-lab/setup.sh
new file mode 100755
index 0000000..79f5556
--- /dev/null
+++ b/dist-lab/setup.sh
@@ -0,0 +1,121 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# ─────────────────────────────────────────────────────────────────────
+# dist-lab/setup.sh
+# Bootstrap a fresh VM for running the PyTorch distributed lab.
+# Assumes: non-root user with sudo, Debian/Ubuntu-based OS.
+# Usage: bash setup.sh [--gpu]
+#   --gpu   Install CUDA-enabled PyTorch + NCCL (default: CPU-only)
+# ─────────────────────────────────────────────────────────────────────
+
+GPU=false
+for arg in "$@"; do
+  case "$arg" in
+    --gpu) GPU=true ;;
+    *) echo "Unknown arg: $arg"; exit 1 ;;
+  esac
+done
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+VENV_DIR="$SCRIPT_DIR/.venv"
+
+# ── 1. System packages ──────────────────────────────────────────────
+echo ">>> Installing system dependencies..."
+# DPkg::Lock::Timeout makes apt wait (up to 5 min) for the lock instead
+# of failing immediately — avoids races with unattended-upgrades on boot.
+sudo apt-get update -qq -o DPkg::Lock::Timeout=300
+sudo apt-get install -y -qq -o DPkg::Lock::Timeout=300 \
+  build-essential \
+  curl \
+  git \
+  net-tools \
+  iproute2 \
+  htop \
+  tmux \
+  python3-dev
+
+# ── 2. Install uv ───────────────────────────────────────────────────
+if ! command -v uv &>/dev/null; then
+  echo ">>> Installing uv..."
+  mkdir -p "$HOME/.config" "$HOME/.local/bin"
+  # Ensure ~/.config is owned by the current user (it may have been
+  # created by a root-level process on first boot).
+  if [ "$(stat -c '%U' "$HOME/.config")" != "$(whoami)" ]; then
+    sudo chown -R "$(whoami)":"$(id -gn)" "$HOME/.config"
+  fi
+  curl -LsSf https://astral.sh/uv/install.sh | INSTALLER_NO_MODIFY_PATH=1 sh
+  # Make uv available in current shell
+  export PATH="$HOME/.local/bin:$PATH"
+else
+  echo ">>> uv already installed: $(uv --version)"
+fi
+
+# ── 3. Create venv ──────────────────────────────────────────────────
+echo ">>> Creating virtual environment at $VENV_DIR ..."
+uv venv "$VENV_DIR" --python 3.11
+
+# ── 4. Install Python packages ──────────────────────────────────────
+echo ">>> Installing Python packages..."
+
+if [ "$GPU" = true ]; then
+  echo "    (GPU mode — installing CUDA-enabled PyTorch)"
+  uv pip install --python "$VENV_DIR/bin/python" \
+    torch torchvision torchaudio \
+    --index-url https://download.pytorch.org/whl/cu124
+else
+  echo "    (CPU mode — installing CPU-only PyTorch)"
+  uv pip install --python "$VENV_DIR/bin/python" \
+    torch torchvision torchaudio \
+    --index-url https://download.pytorch.org/whl/cpu
+fi
+
+# Useful extras for distributed work
+uv pip install --python "$VENV_DIR/bin/python" \
+  tensorboard \
+  torch-tb-profiler \
+  pynvml \
+  psutil \
+  py-spy \
+  rich
+
+# ── 5. Verify install ───────────────────────────────────────────────
+echo ""
+echo ">>> Verifying installation..."
+"$VENV_DIR/bin/python" -c "
+import torch
+print(f'  PyTorch       : {torch.__version__}')
+print(f'  CUDA available: {torch.cuda.is_available()}')
+if torch.cuda.is_available():
+    print(f'  GPU count     : {torch.cuda.device_count()}')
+    for i in range(torch.cuda.device_count()):
+        print(f'    [{i}] {torch.cuda.get_device_name(i)}')
+import torch.distributed as dist
+print(f'  Gloo available: {dist.is_gloo_available()}')
+print(f'  NCCL available: {dist.is_nccl_available()}')
+"
+
+# ── 6. Print activation instructions ────────────────────────────────
+echo ""
+echo "========================================="
+echo " Setup complete!"
+echo "========================================="
+echo ""
+echo "Activate the venv:"
+echo "  source $VENV_DIR/bin/activate"
+echo ""
+echo "Then run (example single-node, 4 processes):"
+echo "  torchrun --nnodes=1 --nproc-per-node=4 \\"
+echo "    --rdzv-id=local --rdzv-backend=c10d \\"
+echo "    --rdzv-endpoint=127.0.0.1:29400 \\"
+echo "    01_basic_collectives.py --backend gloo"
+echo ""
+if [ "$GPU" = true ]; then
+  echo "GPU mode was selected. For NCCL runs, set:"
+  echo "  export NCCL_SOCKET_IFNAME=<iface>"
+else
+  echo "CPU mode was selected. Re-run with --gpu for CUDA/NCCL support."
+fi
+echo ""
+
+#sudo apt-get update -qq && sudo apt-get install -y -qq git && cd /home/$(whoami) && git clone https://github.com/gpu-poor/multi-node-gpu.git && cd multi-node-gpu && git checkout dist-lab-exercises && cd dist-lab && bash setup.sh
\ No newline at end of file